diff --git a/cache.go b/cache.go index 14038f1..759ec1a 100644 --- a/cache.go +++ b/cache.go @@ -5,22 +5,26 @@ import ( "time" ) +// CacheItem represents an item in the cache type CacheItem struct { Value interface{} ExpiresAt time.Time } +// Cache is a simple in-memory cache type Cache struct { items map[string]CacheItem mutex sync.RWMutex } +// NewCache creates a new Cache func NewCache() *Cache { return &Cache{ items: make(map[string]CacheItem), } } +// Set adds an item to the cache func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { c.mutex.Lock() defer c.mutex.Unlock() @@ -30,6 +34,7 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) { } } +// Get retrieves an item from the cache func (c *Cache) Get(key string) (interface{}, bool) { c.mutex.RLock() defer c.mutex.RUnlock() @@ -44,12 +49,14 @@ func (c *Cache) Get(key string) (interface{}, bool) { return item.Value, true } +// Delete removes an item from the cache func (c *Cache) Delete(key string) { c.mutex.Lock() defer c.mutex.Unlock() delete(c.items, key) } +// Cleanup removes expired items from the cache func (c *Cache) Cleanup() { c.mutex.Lock() defer c.mutex.Unlock() diff --git a/helpers.go b/helpers.go index 73d514f..61d97d2 100644 --- a/helpers.go +++ b/helpers.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "net/http" "net/url" "strings" @@ -16,6 +17,7 @@ import ( "github.com/gorilla/sessions" ) +// generateNonce generates a random nonce func generateNonce() (string, error) { nonceBytes := make([]byte, 32) _, err := rand.Read(nonceBytes) @@ -25,6 +27,7 @@ func generateNonce() (string, error) { return base64.URLEncoding.EncodeToString(nonceBytes), nil } +// buildFullURL constructs a full URL from scheme, host, and path func buildFullURL(scheme, host, path string) string { if scheme == "" { scheme = "http" @@ -32,7 +35,8 @@ func buildFullURL(scheme, host, path string) string { return fmt.Sprintf("%s://%s%s", scheme, host, path) } -func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) { +// exchangeTokens exchanges a code or refresh token for tokens +func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, "client_id": {t.clientID}, @@ -58,14 +62,20 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken } defer resp.Body.Close() - var result map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var tokenResponse TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { return nil, fmt.Errorf("failed to decode token response: %w", err) } - return result, nil + return &tokenResponse, nil } +// TokenResponse represents the response from the token endpoint type TokenResponse struct { IDToken string `json:"id_token"` AccessToken string `json:"access_token"` @@ -74,47 +84,20 @@ type TokenResponse struct { TokenType string `json:"token_type"` } +// getNewTokenWithRefreshToken refreshes the token using the refresh token func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { ctx := context.Background() - result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") + tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "") if err != nil { return nil, fmt.Errorf("failed to refresh token: %w", err) } - newAccessToken, ok := result["access_token"].(string) - if !ok || newAccessToken == "" { - return nil, fmt.Errorf("no access_token field in token response") - } + t.logger.Debugf("Token response: %+v", tokenResponse) - rawIDToken, ok := result["id_token"].(string) - if !ok || rawIDToken == "" { - return nil, fmt.Errorf("no id_token field in token response") - } - - newRefreshToken, ok := result["refresh_token"].(string) - if !ok || newRefreshToken == "" { - return nil, fmt.Errorf("no refresh_token field in token response") - } - - response := &TokenResponse{ - IDToken: rawIDToken, - AccessToken: newAccessToken, - ExpiresIn: int(result["expires_in"].(float64)), - TokenType: result["token_type"].(string), - } - - // The refresh token might not be returned if it hasn't changed - if newRefreshToken != refreshToken { - response.RefreshToken = newRefreshToken - } else { - response.RefreshToken = refreshToken - } - - t.logger.Debug("Token response: %+v", response) - - return response, nil + return tokenResponse, nil } +// handleLogout handles the user logout func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.store.Get(req, cookieName) t.logger.Debugf("Logging out user") @@ -123,28 +106,41 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { return } - if idToken, ok := session.Values["id_token"].(string); ok { - err := t.RevokeTokenWithProvider(idToken) - if err != nil { - handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger) - return + // Revoke tokens if available + if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" { + if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil { + t.logger.Errorf("Failed to revoke refresh token: %v", err) } - t.RevokeToken(idToken) + t.RevokeToken(refreshToken) + } + if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" { + if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil { + t.logger.Errorf("Failed to revoke access token: %v", err) + } + t.RevokeToken(accessToken) } + // Remove tokens from session + delete(session.Values, "id_token") + delete(session.Values, "refresh_token") + delete(session.Values, "access_token") + delete(session.Values, "authenticated") + + // Set session options to delete the session session.Options = defaultSessionOptions - // Clear the session session.Options.MaxAge = -1 - session.Values = make(map[interface{}]interface{}) - err = session.Save(req, rw) - if err != nil { + + if err := session.Save(req, rw); err != nil { handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger) return } - http.Error(rw, "Logged out", http.StatusForbidden) + // Redirect or display logout message + rw.WriteHeader(http.StatusOK) + rw.Write([]byte("Logged out successfully")) } +// handleExpiredToken handles the case when a token has expired func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) { // Clear the existing session session.Options.MaxAge = -1 @@ -169,6 +165,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque t.initiateAuthenticationFunc(rw, req, session, t.redirectURL) } +// handleCallback handles the callback from the OIDC provider func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) { session, err := t.store.Get(req, cookieName) if err != nil { @@ -179,6 +176,34 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) + // Check for errors in the query parameters + if req.URL.Query().Get("error") != "" { + errorDescription := req.URL.Query().Get("error_description") + t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) + http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) + return + } + + // Validate the state parameter matches the session's CSRF token + state := req.URL.Query().Get("state") + if state == "" { + t.logger.Error("No state in callback") + http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) + return + } + csrfToken, ok := session.Values["csrf"].(string) + if !ok || csrfToken == "" { + t.logger.Error("CSRF token missing in session") + http.Error(rw, "CSRF token missing", http.StatusBadRequest) + return + } + if state != csrfToken { + t.logger.Error("State parameter does not match CSRF token in session") + http.Error(rw, "Invalid state parameter", http.StatusBadRequest) + return + } + + // Proceed to exchange the code for tokens code := req.URL.Query().Get("code") if code == "" { t.logger.Error("No code in callback") @@ -186,20 +211,29 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) return } - token, err := t.exchangeCodeForTokenFunc(code) + tokenResponse, err := t.exchangeCodeForTokenFunc(code) if err != nil { t.logger.Errorf("Failed to exchange code for token: %v", err) http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } - idToken, ok := token["id_token"].(string) - if !ok || idToken == "" { + // Extract id_token + idToken := tokenResponse.IDToken + if idToken == "" { t.logger.Error("No id_token in token response") http.Error(rw, "Authentication failed", http.StatusInternalServerError) return } + // Verify the id_token + if err := t.verifyToken(idToken); err != nil { + t.logger.Errorf("Failed to verify id_token: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Extract claims from id_token claims, err := t.extractClaimsFunc(idToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) @@ -207,6 +241,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) return } + // Verify the nonce claim matches the one stored in session + nonceClaim, ok := claims["nonce"].(string) + if !ok || nonceClaim == "" { + t.logger.Error("Nonce claim missing in id_token") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + sessionNonce, ok := session.Values["nonce"].(string) + if !ok || sessionNonce == "" { + t.logger.Error("Nonce not found in session") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + if nonceClaim != sessionNonce { + t.logger.Error("Nonce claim does not match session nonce") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Get the email from claims email, _ := claims["email"].(string) if email == "" || !t.isAllowedDomain(email) { t.logger.Errorf("Invalid or disallowed email: %s", email) @@ -214,11 +268,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) return } + // Store tokens and authentication status in session session.Values["authenticated"] = true session.Values["email"] = email session.Values["id_token"] = idToken + session.Values["refresh_token"] = tokenResponse.RefreshToken session.Options = defaultSessionOptions + // Remove CSRF and nonce from session + delete(session.Values, "csrf") + delete(session.Values, "nonce") + if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save session: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) @@ -226,16 +286,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) } t.logger.Debugf("Authentication successful. User email: %s", email) - http.Redirect(rw, req, func() string { - if path, ok := session.Values["incoming_path"].(string); ok { - t.logger.Debug("Redirecting to incoming path from original request: %s", path) - return path - } - t.logger.Debug("Redirecting to root path as no incoming path found") - return "/" - }(), http.StatusFound) + + // Redirect to the original requested path or default to root + redirectPath := "/" + if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath { + t.logger.Debugf("Redirecting to incoming path from original request: %s", path) + redirectPath = path + } + http.Redirect(rw, req, redirectPath, http.StatusFound) } +// extractClaims extracts claims from a JWT token func extractClaims(tokenString string) (map[string]interface{}, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { @@ -255,23 +316,27 @@ func extractClaims(tokenString string) (map[string]interface{}, error) { return claims, nil } +// TokenBlacklist maintains a blacklist of tokens type TokenBlacklist struct { blacklist map[string]time.Time mutex sync.RWMutex } +// NewTokenBlacklist creates a new TokenBlacklist func NewTokenBlacklist() *TokenBlacklist { return &TokenBlacklist{ blacklist: make(map[string]time.Time), } } +// Add adds a token to the blacklist func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) { tb.mutex.Lock() defer tb.mutex.Unlock() tb.blacklist[tokenID] = expiration } +// IsBlacklisted checks if a token is blacklisted func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { tb.mutex.RLock() defer tb.mutex.RUnlock() @@ -279,6 +344,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { return exists && time.Now().Before(expiration) } +// Cleanup removes expired tokens from the blacklist func (tb *TokenBlacklist) Cleanup() { tb.mutex.Lock() defer tb.mutex.Unlock() @@ -290,26 +356,25 @@ func (tb *TokenBlacklist) Cleanup() { } } +// TokenCache caches tokens type TokenCache struct { cache *Cache } -type TokenInfo struct { - Token string - ExpiresAt time.Time -} - +// NewTokenCache creates a new TokenCache func NewTokenCache() *TokenCache { return &TokenCache{ cache: NewCache(), } } +// Set sets a token in the cache func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) { token = "t-" + token tc.cache.Set(token, claims, expiration) } +// Get retrieves a token from the cache func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { token = "t-" + token value, found := tc.cache.Get(token) @@ -320,37 +385,28 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) { return claims, ok } +// Delete removes a token from the cache func (tc *TokenCache) Delete(token string) { token = "t-" + token tc.cache.Delete(token) } +// Cleanup cleans up expired tokens from the cache func (tc *TokenCache) Cleanup() { tc.cache.Cleanup() } -func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) { - data := url.Values{} - data.Set("grant_type", "authorization_code") - data.Set("client_id", t.clientID) - data.Set("client_secret", t.clientSecret) - data.Set("code", code) - data.Set("redirect_uri", t.redirectURL) - - resp, err := t.httpClient.PostForm(t.tokenURL, data) +// exchangeCodeForToken exchanges the authorization code for tokens +func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) { + ctx := context.Background() + tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL) if err != nil { - return nil, fmt.Errorf("failed to exchange token: %v", err) + return nil, fmt.Errorf("failed to exchange code for token: %w", err) } - defer resp.Body.Close() - - var result map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode token response: %v", err) - } - - return result, nil + return tokenResponse, nil } +// createStringMap creates a map from a slice of strings func createStringMap(keys []string) map[string]struct{} { result := make(map[string]struct{}) for _, key := range keys { diff --git a/jwk.go b/jwk.go index a434e68..b1d096c 100644 --- a/jwk.go +++ b/jwk.go @@ -4,17 +4,19 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" + "math/big" + "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" - "math/big" "net/http" "sync" "time" ) +// JWK represents a JSON Web Key type JWK struct { Kty string `json:"kty"` Kid string `json:"kid"` @@ -27,20 +29,24 @@ type JWK struct { Y string `json:"y"` } +// JWKSet represents a set of JWKs type JWKSet struct { Keys []JWK `json:"keys"` } +// JWKCache caches the JWKs type JWKCache struct { jwks *JWKSet expiresAt time.Time mutex sync.RWMutex } +// JWKCacheInterface defines the interface for the JWK cache type JWKCacheInterface interface { GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) } +// GetJWKS gets the JWKS, either from cache or by fetching it func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { c.mutex.RLock() if c.jwks != nil && time.Now().Before(c.expiresAt) { @@ -67,6 +73,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er return jwks, nil } +// fetchJWKS fetches the JWKS from the provider func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { resp, err := httpClient.Get(jwksURL) if err != nil { @@ -86,36 +93,7 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { return &jwks, nil } -func verifyAudience(tokenAudience interface{}, expectedAudience string) error { - switch aud := tokenAudience.(type) { - case string: - if aud != expectedAudience { - return fmt.Errorf("invalid audience") - } - case []interface{}: - found := false - for _, v := range aud { - if str, ok := v.(string); ok && str == expectedAudience { - found = true - break - } - } - if !found { - return fmt.Errorf("invalid audience") - } - default: - return fmt.Errorf("invalid 'aud' claim type") - } - return nil -} - -func verifyIssuer(tokenIssuer, expectedIssuer string) error { - if tokenIssuer != expectedIssuer { - return fmt.Errorf("invalid issuer") - } - return nil -} - +// jwkToPEM converts a JWK to PEM format func jwkToPEM(jwk *JWK) ([]byte, error) { converter, ok := jwkConverters[jwk.Kty] if !ok { @@ -131,41 +109,45 @@ var jwkConverters = map[string]jwkToPEMConverter{ "EC": ecJWKToPEM, } +// rsaJWKToPEM converts an RSA JWK to PEM func rsaJWKToPEM(jwk *JWK) ([]byte, error) { - n, err := base64.RawURLEncoding.DecodeString(jwk.N) + nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) if err != nil { return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) } - e, err := base64.RawURLEncoding.DecodeString(jwk.E) + eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E) if err != nil { return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err) } - publicKey := &rsa.PublicKey{ - N: new(big.Int).SetBytes(n), - E: int(new(big.Int).SetBytes(e).Int64()), + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + + pubKey := &rsa.PublicKey{ + N: n, + E: int(e.Int64()), } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey) if err != nil { - return nil, fmt.Errorf("failed to marshal public key: %w", err) + return nil, fmt.Errorf("failed to marshal RSA public key: %w", err) } - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: publicKeyBytes, + pubKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubKeyBytes, }) - return publicKeyPEM, nil + return pubKeyPEM, nil } +// ecJWKToPEM converts an EC JWK to PEM func ecJWKToPEM(jwk *JWK) ([]byte, error) { - x, err := base64.RawURLEncoding.DecodeString(jwk.X) + xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X) if err != nil { return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err) } - - y, err := base64.RawURLEncoding.DecodeString(jwk.Y) + yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y) if err != nil { return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err) } @@ -182,21 +164,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) { return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv) } - publicKey := &ecdsa.PublicKey{ + pubKey := &ecdsa.PublicKey{ Curve: curve, - X: new(big.Int).SetBytes(x), - Y: new(big.Int).SetBytes(y), + X: new(big.Int).SetBytes(xBytes), + Y: new(big.Int).SetBytes(yBytes), } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey) if err != nil { - return nil, fmt.Errorf("failed to marshal public key: %w", err) + return nil, fmt.Errorf("failed to marshal EC public key: %w", err) } - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + pubKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", - Bytes: publicKeyBytes, + Bytes: pubKeyBytes, }) - return publicKeyPEM, nil + return pubKeyPEM, nil } diff --git a/jwt.go b/jwt.go index 060edd5..a04dfef 100644 --- a/jwt.go +++ b/jwt.go @@ -4,29 +4,36 @@ import ( "crypto" "crypto/ecdsa" "crypto/rsa" + "crypto/sha256" + "strings" + "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" - "math/big" - "strings" + "time" ) +// JWT represents a JSON Web Token type JWT struct { Header map[string]interface{} Claims map[string]interface{} - Signature string + Signature []byte + Token string } +// parseJWT parses a JWT token string into a JWT struct func parseJWT(tokenString string) (*JWT, error) { parts := strings.Split(tokenString, ".") if len(parts) != 3 { return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) } - jwt := &JWT{} + jwt := &JWT{ + Token: tokenString, + } // Decode and unmarshal the header headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) @@ -46,12 +53,17 @@ func parseJWT(tokenString string) (*JWT, error) { return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err) } - // Set the signature - jwt.Signature = parts[2] + // Decode the signature + signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err) + } + jwt.Signature = signatureBytes return jwt, nil } +// Verify verifies the standard claims in the JWT func (j *JWT) Verify(issuerURL, clientID string) error { claims := j.Claims @@ -95,6 +107,39 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } +// verifyAudience verifies the audience claim +func verifyAudience(tokenAudience interface{}, expectedAudience string) error { + switch aud := tokenAudience.(type) { + case string: + if aud != expectedAudience { + return fmt.Errorf("invalid audience") + } + case []interface{}: + found := false + for _, v := range aud { + if str, ok := v.(string); ok && str == expectedAudience { + found = true + break + } + } + if !found { + return fmt.Errorf("invalid audience") + } + default: + return fmt.Errorf("invalid 'aud' claim type") + } + return nil +} + +// verifyIssuer verifies the issuer claim +func verifyIssuer(tokenIssuer, expectedIssuer string) error { + if tokenIssuer != expectedIssuer { + return fmt.Errorf("invalid issuer") + } + return nil +} + +// verifyExpiration checks if the token has expired func verifyExpiration(expiration float64) error { expirationTime := time.Unix(int64(expiration), 0) if time.Now().After(expirationTime) { @@ -103,7 +148,24 @@ func verifyExpiration(expiration float64) error { return nil } -func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error { +// verifyIssuedAt checks if the token was issued in the future +func verifyIssuedAt(issuedAt float64) error { + issuedAtTime := time.Unix(int64(issuedAt), 0) + if time.Now().Before(issuedAtTime) { + return fmt.Errorf("token used before issued") + } + return nil +} + +// verifySignature verifies the token signature +func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error { + parts := strings.Split(tokenString, ".") + signedContent := parts[0] + "." + parts[1] + signature, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + block, _ := pem.Decode(publicKeyPEM) if block == nil { return fmt.Errorf("failed to parse PEM block containing the public key") @@ -114,69 +176,19 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte return fmt.Errorf("failed to parse public key: %w", err) } - var hash crypto.Hash - var verifyFunc func(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error - - switch alg { - case "RS256", "RS384", "RS512": - hash = crypto.SHA256 // SHA384 and SHA512 are used for RS384 and RS512 respectively. - verifyFunc = rsaVerifyPKCS1v15 - case "PS256", "PS384", "PS512": - hash = crypto.SHA256 // SHA384 and SHA512 are used for PS384 and PS512 respectively. - verifyFunc = rsaVerifyPSS - case "ES256", "ES384", "ES512": - hash = crypto.SHA256 // SHA384 and SHA512 are used for ES384 and ES512 respectively. - verifyFunc = ecdsaVerify - default: - return fmt.Errorf("unsupported algorithm: %s", alg) - } - - h := hash.New() + h := sha256.New() h.Write([]byte(signedContent)) hashed := h.Sum(nil) - return verifyFunc(pubKey, hashed, signature, hash) -} - -func rsaVerifyPKCS1v15(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { - pubKey, ok := publicKey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("invalid public key type for RSA: %T", publicKey) - } - return rsa.VerifyPKCS1v15(pubKey, hash, hashed, signature) -} - -func rsaVerifyPSS(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { - pubKey, ok := publicKey.(*rsa.PublicKey) - if !ok { - return fmt.Errorf("invalid public key type for RSA: %T", publicKey) - } - opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} - return rsa.VerifyPSS(pubKey, crypto.SHA256, hashed, signature, opts) -} - -func ecdsaVerify(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error { - pubKey, ok := publicKey.(*ecdsa.PublicKey) - if !ok { - return fmt.Errorf("invalid public key type for ECDSA: %T", publicKey) - } - keyBytes := (pubKey.Params().BitSize + 7) / 8 - if len(signature) != 2*keyBytes { - return fmt.Errorf("invalid signature length for ECDSA: expected %d bytes, got %d bytes", 2*keyBytes, len(signature)) - } - r := new(big.Int).SetBytes(signature[:keyBytes]) - s := new(big.Int).SetBytes(signature[keyBytes:]) - - if ecdsa.Verify(pubKey, hashed, r, s) { + switch pubKey := pubKey.(type) { + case *rsa.PublicKey: + return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hashed, signature) + case *ecdsa.PublicKey: + if !ecdsa.VerifyASN1(pubKey, hashed, signature) { + return fmt.Errorf("invalid ECDSA signature") + } return nil + default: + return fmt.Errorf("unsupported public key type: %T", pubKey) } - return fmt.Errorf("invalid ECDSA signature") -} - -func verifyIssuedAt(issuedAt float64) error { - issuedAtTime := time.Unix(int64(issuedAt), 0) - if time.Now().Before(issuedAtTime) { - return fmt.Errorf("token used before issued") - } - return nil } diff --git a/main.go b/main.go index 5f72f76..62e5561 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package traefikoidc import ( "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -19,16 +18,19 @@ import ( "golang.org/x/time/rate" ) -const ConstSessionTimeout = 86400 +const ConstSessionTimeout = 86400 // Session timeout in seconds +// TokenVerifier interface for token verification type TokenVerifier interface { VerifyToken(token string) error } +// JWTVerifier interface for JWT verification type JWTVerifier interface { VerifyJWTSignatureAndClaims(jwt *JWT, token string) error } +// TraefikOidc is the main struct for the OIDC middleware type TraefikOidc struct { next http.Handler name string @@ -58,12 +60,13 @@ type TraefikOidc struct { allowedUserDomains map[string]struct{} allowedRolesAndGroups map[string]struct{} initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) - exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error) + exchangeCodeForTokenFunc func(code string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) initOnce sync.Once initComplete chan struct{} } +// ProviderMetadata holds OIDC provider metadata type ProviderMetadata struct { Issuer string `json:"issuer"` AuthURL string `json:"authorization_endpoint"` @@ -72,36 +75,45 @@ type ProviderMetadata struct { RevokeURL string `json:"revocation_endpoint"` } +// defaultExcludedURLs are the paths that are excluded from authentication var defaultExcludedURLs = map[string]struct{}{ "/favicon": {}, } var newTicker = time.NewTicker +// VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { - t.logger.Debugf("Verifying token: %s", token) + t.logger.Debugf("Verifying token") + + // Rate limiting if !t.limiter.Allow() { return fmt.Errorf("rate limit exceeded") } + // Check if token is blacklisted if t.tokenBlacklist.IsBlacklisted(token) { return fmt.Errorf("token is blacklisted") } + // Check if token is cached if _, exists := t.tokenCache.Get(token); exists { t.logger.Debugf("Token is valid and cached") return nil // Token is valid and cached } + // Parse the JWT jwt, err := parseJWT(token) if err != nil { return fmt.Errorf("failed to parse JWT: %w", err) } + // Verify JWT signature and claims if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil { return err } + // Cache the token until it expires expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) now := time.Now() duration := expirationTime.Sub(now) @@ -110,26 +122,27 @@ func (t *TraefikOidc) VerifyToken(token string) error { return nil } +// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error { - t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header) + t.logger.Debugf("Verifying JWT signature and claims") + // Get JWKS jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) if err != nil { return fmt.Errorf("failed to get JWKS: %w", err) } + // Retrieve key ID and algorithm from JWT header kid, ok := jwt.Header["kid"].(string) if !ok { return fmt.Errorf("missing key ID in token header") } - t.logger.Debugf("Token kid: %s", kid) - alg, ok := jwt.Header["alg"].(string) if !ok { return fmt.Errorf("missing algorithm in token header") } - t.logger.Debugf("Token alg: %s", alg) + // Find the matching key in JWKS var matchingKey *JWK for _, key := range jwks.Keys { if key.Kid == kid { @@ -137,48 +150,35 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error break } } - if matchingKey == nil { return fmt.Errorf("no matching public key found for kid: %s", kid) } - t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg) + // Convert JWK to PEM format publicKeyPEM, err := jwkToPEM(matchingKey) if err != nil { return fmt.Errorf("failed to convert JWK to PEM: %w", err) } - t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM)) - parts := strings.Split(token, ".") - if len(parts) != 3 { - return fmt.Errorf("invalid token format") - } - - signedContent := parts[0] + "." + parts[1] - signature, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return fmt.Errorf("failed to decode signature: %w", err) - } - - if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil { - t.logger.Errorf("Signature verification failed: %v", err) + // Verify the signature + if err := verifySignature(token, publicKeyPEM, alg); err != nil { return fmt.Errorf("signature verification failed: %w", err) } - t.logger.Debug("Signature verified successfully") // Verify standard claims if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { return fmt.Errorf("standard claim verification failed: %w", err) } - t.logger.Debug("Standard claims verified successfully") return nil } +// New creates a new instance of the OIDC middleware func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey)) store.Options = defaultSessionOptions + // Setup HTTP client transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -217,9 +217,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } return config.LogoutURL }(), - tokenBlacklist: NewTokenBlacklist(), - jwkCache: &JWKCache{}, - + tokenBlacklist: NewTokenBlacklist(), + jwkCache: &JWKCache{}, clientID: config.ClientID, clientSecret: config.ClientSecret, forceHTTPS: config.ForceHTTPS, @@ -229,17 +228,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h httpClient: httpClient, logger: NewLogger(config.LogLevel), excludedURLs: createStringMap(config.ExcludedURLs), - redirectURL: "", allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), initComplete: make(chan struct{}), } - t.initiateAuthenticationFunc = t.defaultInitiateAuthentication - t.exchangeCodeForTokenFunc = t.exchangeCodeForToken t.extractClaimsFunc = extractClaims + t.exchangeCodeForTokenFunc = t.exchangeCodeForToken + t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + } - // add defaultExcludedURLs to excludedURLs + // Add default excluded URLs for k, v := range defaultExcludedURLs { t.excludedURLs[k] = v } @@ -248,14 +248,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.jwtVerifier = t t.startTokenCleanup() go t.initializeMetadata(config.ProviderURL) + return t, nil } +// initializeMetadata discovers and initializes the provider metadata func (t *TraefikOidc) initializeMetadata(providerURL string) { t.initOnce.Do(func() { metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) if err != nil { - t.logger.Error("Failed to discover provider metadata: %v", err) + t.logger.Errorf("Failed to discover provider metadata: %v", err) } else { t.logger.Debug("Provider metadata discovered successfully") t.jwksURL = metadata.JWKSURL @@ -268,6 +270,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { }) } +// discoverProviderMetadata fetches the OIDC provider metadata func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) { wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" @@ -281,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo var lastErr error for attempt := 0; attempt < maxRetries; attempt++ { if time.Since(start) > totalTimeout { - l.Error("Timeout exceeded while fetching provider metadata") + l.Errorf("Timeout exceeded while fetching provider metadata") return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr) } @@ -293,18 +296,20 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo lastErr = err + // Exponential backoff delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay if delay > maxDelay { delay = maxDelay } - l.Debug("Failed to fetch provider metadata, retrying in %s", delay) + l.Debugf("Failed to fetch provider metadata, retrying in %s", delay) time.Sleep(delay) } - l.Error("Max retries exceeded while fetching provider metadata") + l.Errorf("Max retries exceeded while fetching provider metadata") return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr) } +// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) { resp, err := httpClient.Get(wellKnownURL) if err != nil { @@ -327,6 +332,7 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad return &metadata, nil } +// ServeHTTP is the main handler for the middleware func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { select { case <-t.initComplete: @@ -342,20 +348,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Check if the URL is excluded from authentication if t.determineExcludedURL(req.URL.Path) { t.next.ServeHTTP(rw, req) return } + // Determine the scheme (http/https) and host t.scheme = t.determineScheme(req) defaultSessionOptions.Secure = t.scheme == "https" host := t.determineHost(req) + // Build the redirect URL if not already set if t.redirectURL == "" { t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath) t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL) } + // Get the session session, err := t.store.Get(req, cookieName) if err != nil { t.logger.Errorf("Error getting session: %v", err) @@ -365,16 +375,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.logger.Debugf("Session contents at start: %+v", session.Values) + // Handle logout URL if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) return } + // Handle callback URL if req.URL.Path == t.redirURLPath { t.handleCallback(rw, req) return } + // Check if the user is authenticated authenticated, needsRefresh, expired := t.isUserAuthenticated(session) if expired { @@ -395,84 +408,80 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - // authenticated, _ := session.Values["authenticated"].(bool) - if authenticated { - idToken, ok := session.Values["id_token"].(string) - if !ok || idToken == "" { - t.logger.Errorf("No id_token found in session") - t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) - return - } - - claims, err := extractClaims(idToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) - return - } - - email, _ := claims["email"].(string) - if email == "" { - t.logger.Debugf("No email found in token claims") - t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) - return - } - - if !t.isAllowedDomain(email) { - t.logger.Infof("User with email %s is not from an allowed domain", email) - http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) - return - } - - groups, roles, err := t.extractGroupsAndRoles(idToken) - if err != nil { - t.logger.Errorf("Failed to extract groups and roles: %v", err) - } else { - // Set headers for groups and roles - if len(groups) > 0 { - req.Header.Set("X-User-Groups", strings.Join(groups, ",")) - } - if len(roles) > 0 { - req.Header.Set("X-User-Roles", strings.Join(roles, ",")) - } - } - - if len(t.allowedRolesAndGroups) > 0 { - allowed := false - for _, roleOrGroup := range append(groups, roles...) { - if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok { - allowed = true - break - } - } - if !allowed { - t.logger.Infof("User with email %s does not have any allowed roles or groups", email) - http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) - return - } - } - - req.Header.Set("X-Forwarded-User", email) - - t.next.ServeHTTP(rw, req) + // At this point, the user is authenticated + idToken, ok := session.Values["id_token"].(string) + if !ok || idToken == "" { + t.logger.Errorf("No id_token found in session") + t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) return } - t.logger.Debug("User is not authenticated, initiating authentication") - t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) + claims, err := extractClaims(idToken) + if err != nil { + t.logger.Errorf("Failed to extract claims: %v", err) + t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) + return + } + + email, _ := claims["email"].(string) + if email == "" { + t.logger.Debugf("No email found in token claims") + t.defaultInitiateAuthentication(rw, req, session, t.redirectURL) + return + } + + if !t.isAllowedDomain(email) { + t.logger.Infof("User with email %s is not from an allowed domain", email) + http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) + return + } + + groups, roles, err := t.extractGroupsAndRoles(idToken) + if err != nil { + t.logger.Errorf("Failed to extract groups and roles: %v", err) + } else { + // Set headers for groups and roles + if len(groups) > 0 { + req.Header.Set("X-User-Groups", strings.Join(groups, ",")) + } + if len(roles) > 0 { + req.Header.Set("X-User-Roles", strings.Join(roles, ",")) + } + } + + if len(t.allowedRolesAndGroups) > 0 { + allowed := false + for _, roleOrGroup := range append(groups, roles...) { + if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok { + allowed = true + break + } + } + if !allowed { + t.logger.Infof("User with email %s does not have any allowed roles or groups", email) + http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) + return + } + } + + req.Header.Set("X-Forwarded-User", email) + + t.next.ServeHTTP(rw, req) } +// determineExcludedURL checks if the current request URL is in the excluded list func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { for excludedURL := range t.excludedURLs { if strings.HasPrefix(currentRequest, excludedURL) { - t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL) + t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL) return true } } - t.logger.Debug("URL is not excluded - got %s", currentRequest) + t.logger.Debugf("URL is not excluded - got %s", currentRequest) return false } +// determineScheme determines the scheme (http or https) of the request func (t *TraefikOidc) determineScheme(req *http.Request) string { if t.forceHTTPS { return "https" @@ -486,6 +495,7 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string { return "http" } +// determineHost determines the host of the request func (t *TraefikOidc) determineHost(req *http.Request) string { if host := req.Header.Get("X-Forwarded-Host"); host != "" { return host @@ -493,6 +503,7 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { return req.Host } +// isUserAuthenticated checks if the user is authenticated func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) { authenticated, _ := session.Values["authenticated"].(bool) t.logger.Debugf("Session authenticated value: %v", authenticated) @@ -543,13 +554,16 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool return true, false, false // Token is valid and not expiring soon } +// defaultInitiateAuthentication initiates the authentication process func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) { + // Generate CSRF token csrfToken := uuid.New().String() session.Values["csrf"] = csrfToken session.Values["incoming_path"] = req.URL.Path session.Options = defaultSessionOptions t.logger.Debugf("Setting CSRF token: %s", csrfToken) + // Generate nonce nonce, err := generateNonce() if err != nil { http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) @@ -558,20 +572,24 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req session.Values["nonce"] = nonce t.logger.Debugf("Setting nonce: %s", nonce) + // Save the session if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save session: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) return } + // Build the authentication URL authURL := t.buildAuthURL(redirectURL, csrfToken, nonce) http.Redirect(rw, req, authURL, http.StatusFound) } +// verifyToken verifies the token using the token verifier func (t *TraefikOidc) verifyToken(token string) error { return t.tokenVerifier.VerifyToken(token) } +// buildAuthURL constructs the authentication URL func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { params := url.Values{} params.Set("client_id", t.clientID) @@ -585,6 +603,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { return t.authURL + "?" + params.Encode() } +// startTokenCleanup starts the token cleanup goroutine func (t *TraefikOidc) startTokenCleanup() { ticker := newTicker(1 * time.Minute) go func() { @@ -596,6 +615,7 @@ func (t *TraefikOidc) startTokenCleanup() { }() } +// RevokeToken adds the token to the blacklist func (t *TraefikOidc) RevokeToken(token string) { // Remove from cache t.tokenCache.Delete(token) @@ -610,12 +630,13 @@ func (t *TraefikOidc) RevokeToken(token string) { } } -func (t *TraefikOidc) RevokeTokenWithProvider(token string) error { +// RevokeTokenWithProvider revokes the token with the provider +func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { t.logger.Debugf("Revoking token with provider") data := url.Values{ "token": {token}, - "token_type_hint": {"access_token", "refresh_token"}, + "token_type_hint": {tokenType}, "client_id": {t.clientID}, "client_secret": {t.clientSecret}, } @@ -646,10 +667,12 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token string) error { return nil } +// refreshToken refreshes the user's token func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool { t.logger.Debug("Refreshing token") refreshToken, ok := session.Values["refresh_token"].(string) if !ok || refreshToken == "" { + t.logger.Debug("No refresh token found in session") return false } @@ -659,6 +682,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return false } + // Verify the new id_token + if err := t.verifyToken(newToken.IDToken); err != nil { + t.logger.Errorf("Failed to verify new id_token: %v", err) + return false + } + + // Update session with new tokens session.Values["id_token"] = newToken.IDToken session.Values["refresh_token"] = newToken.RefreshToken session.Options = defaultSessionOptions @@ -670,6 +700,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return true } +// isAllowedDomain checks if the user's email domain is allowed func (t *TraefikOidc) isAllowedDomain(email string) bool { if len(t.allowedUserDomains) == 0 { return true // If no domains are specified, all are allowed @@ -685,6 +716,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool { return ok } +// extractGroupsAndRoles extracts groups and roles from the id_token func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) { claims, err := t.extractClaimsFunc(idToken) if err != nil { @@ -706,25 +738,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, } } - if len(groups) == 0 { - t.logger.Debug("No groups found in groups claim, checking roles claim") - } - // Check for roles claim if rolesClaim, ok := claims["roles"]; ok { if rolesSlice, ok := rolesClaim.([]interface{}); ok { for _, role := range rolesSlice { if roleStr, ok := role.(string); ok { - t.logger.Debug("Found role: %s", roleStr) + t.logger.Debugf("Found role: %s", roleStr) roles = append(roles, roleStr) } } } } - if len(roles) == 0 { - t.logger.Debug("No roles found in roles claim") - } - return groups, roles, nil } diff --git a/main_test.go b/main_test.go index a6bfd5c..08b3ecf 100644 --- a/main_test.go +++ b/main_test.go @@ -72,6 +72,7 @@ func (ts *TestSuite) Setup() { "iat": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com", + "nonce": "test-nonce", }) if err != nil { ts.t.Fatalf("Failed to create test JWT: %v", err) @@ -79,33 +80,34 @@ func (ts *TestSuite) Setup() { // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ - issuerURL: "https://test-issuer.com", - clientID: "test-client-id", - clientSecret: "test-client-secret", - jwkCache: ts.mockJWKCache, - jwksURL: "https://test-jwks-url.com", - revocationURL: "https://revocation-endpoint.com", - limiter: rate.NewLimiter(rate.Every(time.Second), 10), - tokenBlacklist: NewTokenBlacklist(), - tokenCache: NewTokenCache(), - logger: NewLogger("info"), - store: sessions.NewCookieStore([]byte("test-secret-key")), - allowedUserDomains: map[string]struct{}{"example.com": {}}, - excludedURLs: map[string]struct{}{"/favicon": {}}, - httpClient: &http.Client{}, - exchangeCodeForTokenFunc: ts.exchangeCodeForTokenFunc, - extractClaimsFunc: extractClaims, - initComplete: make(chan struct{}), + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + revocationURL: "https://revocation-endpoint.com", + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + tokenBlacklist: NewTokenBlacklist(), + tokenCache: NewTokenCache(), + logger: NewLogger("info"), + store: sessions.NewCookieStore([]byte("test-secret-key")), + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}}, + httpClient: &http.Client{}, + extractClaimsFunc: extractClaims, + initComplete: make(chan struct{}), } close(ts.tOidc.initComplete) + ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc ts.tOidc.tokenVerifier = ts.tOidc ts.tOidc.jwtVerifier = ts.tOidc } // Helper functions used by TraefikOidc -func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (map[string]interface{}, error) { - return map[string]interface{}{ - "id_token": ts.token, +func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", }, nil } @@ -453,61 +455,149 @@ func TestHandleCallback(t *testing.T) { tests := []struct { name string queryParams string - exchangeCodeForToken func(code string) (map[string]interface{}, error) + exchangeCodeForToken func(code string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) + sessionSetupFunc func(session *sessions.Session) expectedStatus int }{ { name: "Success", - queryParams: "?code=test-code", - exchangeCodeForToken: func(code string) (map[string]interface{}, error) { - return map[string]interface{}{ - "id_token": "test-id-token", + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@example.com", + "nonce": "test-nonce", }, nil }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, expectedStatus: http.StatusFound, }, { - name: "Missing Code", - queryParams: "", + name: "Missing Code", + queryParams: "", + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, expectedStatus: http.StatusBadRequest, }, { name: "Exchange Code Error", - queryParams: "?code=test-code", - exchangeCodeForToken: func(code string) (map[string]interface{}, error) { + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { return nil, fmt.Errorf("exchange code error") }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, expectedStatus: http.StatusInternalServerError, }, { name: "Missing ID Token", - queryParams: "?code=test-code", - exchangeCodeForToken: func(code string) (map[string]interface{}, error) { - return map[string]interface{}{}, nil + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{}, nil + }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" }, expectedStatus: http.StatusInternalServerError, }, { name: "Disallowed Email", - queryParams: "?code=test-code", - exchangeCodeForToken: func(code string) (map[string]interface{}, error) { - return map[string]interface{}{ - "id_token": "test-id-token", + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@disallowed.com", + "nonce": "test-nonce", }, nil }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, expectedStatus: http.StatusForbidden, }, + { + name: "Invalid State Parameter", + queryParams: "?code=test-code&state=invalid-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + return map[string]interface{}{ + "email": "user@example.com", + "nonce": "test-nonce", + }, nil + }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Nonce Mismatch", + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + return map[string]interface{}{ + "email": "user@example.com", + "nonce": "invalid-nonce", + }, nil + }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Missing Nonce in Claims", + queryParams: "?code=test-code&state=test-csrf-token", + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + return map[string]interface{}{ + "email": "user@example.com", + // Missing nonce + }, nil + }, + sessionSetupFunc: func(session *sessions.Session) { + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + expectedStatus: http.StatusInternalServerError, + }, } for _, tc := range tests { @@ -519,6 +609,8 @@ func TestHandleCallback(t *testing.T) { logger: NewLogger("info"), exchangeCodeForTokenFunc: tc.exchangeCodeForToken, extractClaimsFunc: tc.extractClaimsFunc, + tokenVerifier: ts.tOidc.tokenVerifier, + jwtVerifier: ts.tOidc.jwtVerifier, } // Create request and response recorder @@ -527,6 +619,9 @@ func TestHandleCallback(t *testing.T) { // Create session session, _ := tOidc.store.New(req, cookieName) + if tc.sessionSetupFunc != nil { + tc.sessionSetupFunc(session) + } session.Save(req, rr) // Copy session cookie to request @@ -583,3 +678,145 @@ func TestIsAllowedDomain(t *testing.T) { }) } } + +func TestOIDCHandler(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + ts.token = "valid.jwt.token" + + tests := []struct { + name string + queryParams string + exchangeCodeForToken func(code string) (*TokenResponse, error) + extractClaimsFunc func(tokenString string) (map[string]interface{}, error) + sessionSetupFunc func(session *sessions.Session) + expectedStatus int + blacklist bool + rateLimit bool + cacheToken bool + }{ + { + name: "Missing Code", + queryParams: "", + sessionSetupFunc: func(session *sessions.Session) { + // Set CSRF and nonce values in session + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + // Simulate token exchange + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // Simulate extraction of claims with invalid nonce + return map[string]interface{}{ + "email": "user@example.com", + "nonce": "invalid-nonce", + }, nil + }, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Missing Nonce in Claims", + queryParams: "?code=test-code&state=test-csrf-token", + sessionSetupFunc: func(session *sessions.Session) { + // Set CSRF and nonce values in session + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + // Simulate token exchange + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // Simulate extraction of claims without nonce + return map[string]interface{}{ + "email": "user@example.com", + }, nil + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Invalid State Parameter", + queryParams: "?code=test-code&state=invalid-csrf-token", + sessionSetupFunc: func(session *sessions.Session) { + // Set CSRF and nonce values in session + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + // Simulate token exchange + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // Simulate extraction of claims + return map[string]interface{}{ + "email": "user@example.com", + "nonce": "test-nonce", + }, nil + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Nonce Mismatch", + queryParams: "?code=test-code&state=test-csrf-token", + sessionSetupFunc: func(session *sessions.Session) { + // Set CSRF and nonce values in session + session.Values["csrf"] = "test-csrf-token" + session.Values["nonce"] = "test-nonce" + }, + exchangeCodeForToken: func(code string) (*TokenResponse, error) { + // Simulate token exchange + return &TokenResponse{ + IDToken: ts.token, + RefreshToken: "test-refresh-token", + }, nil + }, + extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // Simulate extraction of claims with mismatched nonce + return map[string]interface{}{ + "email": "user@example.com", + "nonce": "invalid-nonce", + }, nil + }, + expectedStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + tc := tc // Capture range variable + t.Run(tc.name, func(t *testing.T) { + // Reset token blacklist and cache + ts.tOidc.tokenBlacklist = NewTokenBlacklist() + ts.tOidc.tokenCache = NewTokenCache() + ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) + + // Set up the test case + if tc.blacklist { + ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour)) + } + + if tc.rateLimit { + // Exceed rate limit + ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0) + } + + if tc.cacheToken { + // Cache the token with dummy claims + ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{ + "empty": "claim", + }, 60) + } + }) + } +} diff --git a/settings.go b/settings.go index 93ac1f8..4da83e2 100644 --- a/settings.go +++ b/settings.go @@ -14,6 +14,7 @@ const ( cookieName = "_raczylo_oidc" ) +// Config holds the configuration for the OIDC middleware type Config struct { ProviderURL string `json:"providerURL"` RevocationURL string `json:"revocationURL"` @@ -40,6 +41,7 @@ var defaultSessionOptions = &sessions.Options{ Path: "/", } +// CreateConfig creates a new Config with default values func CreateConfig() *Config { c := &Config{} @@ -62,6 +64,7 @@ func CreateConfig() *Config { return c } +// Validate validates the Config func (c *Config) Validate() error { if c.ProviderURL == "" { return fmt.Errorf("providerURL is required") @@ -81,12 +84,14 @@ func (c *Config) Validate() error { return nil } +// Logger is a simple logger with different levels type Logger struct { logError *log.Logger logInfo *log.Logger logDebug *log.Logger } +// NewLogger creates a new Logger func NewLogger(logLevel string) *Logger { logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime) logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime) @@ -106,30 +111,37 @@ func NewLogger(logLevel string) *Logger { } } +// Info logs an info message func (l *Logger) Info(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } +// Debug logs a debug message func (l *Logger) Debug(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } +// Error logs an error message func (l *Logger) Error(format string, args ...interface{}) { l.logError.Printf(format, args...) } +// Infof logs an info message func (l *Logger) Infof(format string, args ...interface{}) { l.logInfo.Printf(format, args...) } +// Debugf logs a debug message func (l *Logger) Debugf(format string, args ...interface{}) { l.logDebug.Printf(format, args...) } +// Errorf logs an error message func (l *Logger) Errorf(format string, args ...interface{}) { l.logError.Printf(format, args...) } +// handleError writes an error message to the response and logs it func handleError(w http.ResponseWriter, message string, code int, logger *Logger) { logger.Error(message) http.Error(w, message, code)