mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
230 lines
5.9 KiB
Go
230 lines
5.9 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/sessions"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TraefikOidc is a Traefik middleware (an http.Handler, see ServeHTTP) that
// passes requests with an OIDC-authenticated session to the next handler and
// redirects all other requests to the provider's authorization endpoint.
type TraefikOidc struct {
	// next is the downstream handler invoked for authenticated requests.
	next http.Handler
	// name is the middleware instance name passed to New.
	name string
	// store holds user sessions (a gorilla cookie store when built by New).
	store sessions.Store
	// redirURLPath is the callback path (config.CallbackURL). Requests whose
	// path equals it are handled as OIDC callbacks, and it is combined with
	// the request scheme and host to build the redirect_uri.
	redirURLPath string
	// issuerURL is the issuer reported by provider discovery.
	issuerURL string
	// jwkCache presumably caches the signing keys fetched from jwksURL;
	// New initialises it empty. TODO confirm against its definition.
	jwkCache *JWKCache
	// tokenBlacklist is cleaned periodically by startTokenCleanup; presumably
	// it records tokens that must no longer be accepted — confirm.
	tokenBlacklist *TokenBlacklist
	// jwksURL is the provider's jwks_uri from discovery.
	jwksURL string
	// clientID and clientSecret are the OAuth2 client credentials from the
	// config; clientID is sent in the authorization request.
	clientID     string
	clientSecret string
	// authURL is the provider's authorization_endpoint; buildAuthURL appends
	// the authorization request parameters to it.
	authURL string
	// tokenURL is the provider's token_endpoint from discovery.
	tokenURL string
	// scopes are requested from the provider, joined with spaces.
	scopes []string
	// limiter is created in New with one token per second and a burst of 100.
	// NOTE(review): it is not consulted anywhere in this file — confirm it is
	// used elsewhere.
	limiter *rate.Limiter
	// forceHTTPS makes determineScheme always report "https".
	forceHTTPS bool
	// scheme is the request scheme last computed by ServeHTTP.
	// NOTE(review): it is overwritten on every request without
	// synchronisation, so concurrent requests race on it.
	scheme string
	// tokenCache is cleaned periodically by startTokenCleanup; presumably it
	// caches verified tokens (see verifyAndCacheToken) — confirm.
	tokenCache *TokenCache
	// httpClient is the client stored for outbound provider requests;
	// presumably used by token-exchange/JWKS code elsewhere — confirm.
	httpClient HTTPClient
	// logger is built from config.LogLevel.
	logger Logger
}
|
|
|
|
// ProviderMetadata is the subset of an OpenID Connect discovery document
// (/.well-known/openid-configuration) that New copies into the middleware.
// Keys missing from the JSON leave the corresponding field empty.
type ProviderMetadata struct {
	// Issuer is the provider's issuer identifier ("issuer").
	Issuer string `json:"issuer"`
	// AuthURL is the authorization endpoint users are redirected to.
	AuthURL string `json:"authorization_endpoint"`
	// TokenURL is the endpoint for exchanging authorization codes.
	TokenURL string `json:"token_endpoint"`
	// JWKSURL is the location of the provider's JSON Web Key Set.
	JWKSURL string `json:"jwks_uri"`
}
|
|
|
|
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
|
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
|
|
store.Options = &sessions.Options{
|
|
Path: "/",
|
|
MaxAge: 3600,
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
}
|
|
|
|
metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
|
|
}
|
|
logger := NewLogger(config.LogLevel)
|
|
|
|
t := &TraefikOidc{
|
|
next: next,
|
|
name: name,
|
|
store: store,
|
|
redirURLPath: config.CallbackURL,
|
|
issuerURL: metadata.Issuer,
|
|
tokenBlacklist: NewTokenBlacklist(),
|
|
jwkCache: &JWKCache{},
|
|
jwksURL: metadata.JWKSURL,
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
forceHTTPS: config.ForceHTTPS,
|
|
authURL: metadata.AuthURL,
|
|
tokenURL: metadata.TokenURL,
|
|
scopes: config.Scopes,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
|
|
tokenCache: NewTokenCache(),
|
|
httpClient: &http.Client{},
|
|
logger: logger,
|
|
}
|
|
|
|
t.startTokenCleanup()
|
|
return t, nil
|
|
}
|
|
|
|
func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) {
|
|
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
|
resp, err := httpClient.Get(wellKnownURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("failed to fetch provider metadata: status code %d", resp.StatusCode)
|
|
}
|
|
|
|
var metadata ProviderMetadata
|
|
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
|
return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
|
|
}
|
|
|
|
return &metadata, nil
|
|
}
|
|
|
|
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
t.scheme = t.determineScheme(req)
|
|
host := t.determineHost(req)
|
|
|
|
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
|
|
t.logger.Infof("Final redirect URL: %s", redirectURL)
|
|
|
|
session, err := t.store.Get(req, cookieName)
|
|
if err != nil {
|
|
t.logger.Errorf("Error getting session: %v", err)
|
|
http.Error(rw, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if req.URL.Path == t.redirURLPath {
|
|
t.logger.Infof("Handling callback, URL: %s", req.URL.String())
|
|
authSuccess, originalPath := t.handleCallback(rw, req)
|
|
if authSuccess {
|
|
http.Redirect(rw, req, originalPath, http.StatusFound)
|
|
return
|
|
}
|
|
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
if t.isUserAuthenticated(session) {
|
|
t.next.ServeHTTP(rw, req)
|
|
return
|
|
}
|
|
|
|
// User is not authenticated, start the auth process
|
|
t.initiateAuthentication(rw, req, session, redirectURL)
|
|
}
|
|
|
|
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
|
scheme := req.URL.Scheme
|
|
if scheme == "" {
|
|
scheme = req.Header.Get("X-Forwarded-Proto")
|
|
}
|
|
if scheme == "" {
|
|
if req.TLS != nil {
|
|
scheme = "https"
|
|
} else {
|
|
scheme = "http"
|
|
}
|
|
}
|
|
if t.forceHTTPS {
|
|
scheme = "https"
|
|
}
|
|
return scheme
|
|
}
|
|
|
|
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
|
host := req.URL.Host
|
|
if host == "" {
|
|
host = req.Header.Get("X-Forwarded-Host")
|
|
}
|
|
if host == "" {
|
|
host = req.Host
|
|
}
|
|
return host
|
|
}
|
|
|
|
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool {
|
|
authenticated, _ := session.Values["authenticated"].(bool)
|
|
if authenticated {
|
|
idToken, ok := session.Values["id_token"].(string)
|
|
if !ok || idToken == "" {
|
|
return false
|
|
}
|
|
return t.verifyToken(idToken) == nil
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
|
csrfToken := uuid.New().String()
|
|
session.Values["csrf"] = csrfToken
|
|
session.Values["incoming_path"] = req.URL.Path
|
|
t.logger.Infof("Setting CSRF token: %s", csrfToken)
|
|
|
|
if err := session.Save(req, rw); err != nil {
|
|
t.logger.Errorf("Failed to save session: %v", err)
|
|
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
nonce, err := generateNonce()
|
|
if err != nil {
|
|
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
|
http.Redirect(rw, req, authURL, http.StatusFound)
|
|
}
|
|
|
|
func (t *TraefikOidc) verifyToken(token string) error {
|
|
return t.verifyAndCacheToken(token)
|
|
}
|
|
|
|
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
|
params := url.Values{
|
|
"client_id": {t.clientID},
|
|
"response_type": {"code"},
|
|
"redirect_uri": {redirectURL},
|
|
"scope": {strings.Join(t.scopes, " ")},
|
|
"state": {state},
|
|
"nonce": {nonce},
|
|
}
|
|
|
|
return fmt.Sprintf("%s?%s", t.authURL, params.Encode())
|
|
}
|
|
|
|
func (t *TraefikOidc) startTokenCleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
go func() {
|
|
for range ticker.C {
|
|
t.tokenCache.Cleanup()
|
|
t.tokenBlacklist.Cleanup()
|
|
}
|
|
}()
|
|
}
|