From 23e019092a7236afa4eea393c08eadb41f73ccac Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 4 Apr 2025 18:42:41 +0100 Subject: [PATCH] Multiple improvements for April 2025 * Improve refresh token handling in the background. Resolves issue when user opens the website, allows the access token to expire, but continues browsing. The background requests are failing with CORS errors to OIDC provider. * fixup! Improve refresh token handling in the background. * Abstract the token blacklisting. --- blacklist.go | 110 -------------- blacklist_test.go | 74 ---------- helpers.go | 145 +----------------- helpers_test.go | 172 +-------------------- jwt.go | 7 +- main.go | 297 ++++++++++++++++++++++++++++++++----- main_test.go | 369 +++++++++++++++++++++++++++++++++++++--------- metadata_cache.go | 13 +- session_test.go | 11 +- 9 files changed, 587 insertions(+), 611 deletions(-) delete mode 100644 blacklist.go delete mode 100644 blacklist_test.go diff --git a/blacklist.go b/blacklist.go deleted file mode 100644 index 27a75c8..0000000 --- a/blacklist.go +++ /dev/null @@ -1,110 +0,0 @@ -package traefikoidc - -import ( - "sync" - "time" -) - -// TokenBlacklist manages a thread-safe list of revoked tokens with expiration. -type TokenBlacklist struct { - tokens map[string]time.Time - mutex sync.RWMutex -} - -// NewTokenBlacklist creates a new token blacklist instance. -func NewTokenBlacklist() *TokenBlacklist { - return &TokenBlacklist{ - tokens: make(map[string]time.Time), - } -} - -// Add adds a token to the blacklist with an expiration time. -func (b *TokenBlacklist) Add(token string, expiry time.Time) { - b.mutex.Lock() - defer b.mutex.Unlock() - - // Clean up expired tokens if we're at capacity - if len(b.tokens) >= 1000 { - now := time.Now() - futureThreshold := now.Add(time.Minute) - for t, exp := range b.tokens { - if now.After(exp) || futureThreshold.After(exp) { - delete(b.tokens, t) - } - } - - // If still at capacity, remove oldest token - if len(b.tokens) >= 1000 { - var oldestToken string - var oldestTime time.Time - first := true - for t, exp := range b.tokens { - if first || exp.Before(oldestTime) { - oldestToken = t - oldestTime = exp - first = false - } - } - if oldestToken != "" { - delete(b.tokens, oldestToken) - } - } - } - - b.tokens[token] = expiry -} - -// IsBlacklisted checks if a token is in the blacklist and not expired. -func (b *TokenBlacklist) IsBlacklisted(token string) bool { - b.mutex.RLock() - defer b.mutex.RUnlock() - - expiry, exists := b.tokens[token] - if !exists { - return false - } - - // If token is expired, remove it and return false - if time.Now().After(expiry) { - // Switch to write lock to remove expired token - b.mutex.RUnlock() - b.mutex.Lock() - delete(b.tokens, token) - b.mutex.Unlock() - b.mutex.RLock() - return false - } - - return true -} - -// Cleanup removes expired tokens from the blacklist. -// Also removes tokens that will expire within the next minute to prevent edge cases. -func (b *TokenBlacklist) Cleanup() { - b.mutex.Lock() - defer b.mutex.Unlock() - - now := time.Now() - futureThreshold := now.Add(time.Minute) - - for token, expiry := range b.tokens { - // Remove tokens that are expired or will expire soon - if now.After(expiry) || futureThreshold.After(expiry) { - delete(b.tokens, token) - } - } -} - -// Remove removes a token from the blacklist regardless of its expiration. -func (b *TokenBlacklist) Remove(token string) { - b.mutex.Lock() - defer b.mutex.Unlock() - delete(b.tokens, token) -} - -// Count returns the current number of tokens in the blacklist. -func (b *TokenBlacklist) Count() int { - b.mutex.RLock() - defer b.mutex.RUnlock() - return len(b.tokens) -} diff --git a/blacklist_test.go b/blacklist_test.go deleted file mode 100644 index 571348c..0000000 --- a/blacklist_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package traefikoidc - -import ( - "testing" - "time" -) - -func TestTokenBlacklist_Add(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - - if !blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be blacklisted, but it was not") - } -} - -func TestTokenBlacklist_IsBlacklisted(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - - if !blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be blacklisted, but it was not") - } - - if blacklist.IsBlacklisted("nonExistentToken") { - t.Errorf("Expected non-existent token to not be blacklisted, but it was") - } -} - -func TestTokenBlacklist_Cleanup(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(-time.Hour) // Expired token - - blacklist.Add(token, expiry) - blacklist.Cleanup() - - if blacklist.IsBlacklisted(token) { - t.Errorf("Expected expired token to be removed after cleanup, but it was not") - } -} - -func TestTokenBlacklist_Remove(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - blacklist.Remove(token) - - if blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be removed, but it was not") - } -} - -func TestTokenBlacklist_Count(t *testing.T) { - blacklist := NewTokenBlacklist() - token1 := "token1" - token2 := "token2" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token1, expiry) - blacklist.Add(token2, expiry) - - if blacklist.Count() != 2 { - t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count()) - } -} diff --git a/helpers.go b/helpers.go index 6d27528..9f9f64e 100644 --- a/helpers.go +++ b/helpers.go @@ -81,7 +81,7 @@ type TokenResponse struct { // - codeOrToken: Either the authorization code or refresh token // - redirectURL: The callback URL for authorization code grant // - codeVerifier: Optional PKCE code verifier for authorization code grant -func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) { +func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, "client_id": {t.clientID}, @@ -153,149 +153,6 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe return tokenResponse, nil } -// handleExpiredToken manages token expiration by clearing the session -// and initiating a new authentication flow. -func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { - // Clear authentication data but preserve CSRF state - session.SetAuthenticated(false) - session.SetAccessToken("") - session.SetRefreshToken("") - session.SetEmail("") - - // Save the cleared session state - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save cleared session: %v", err) - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - - t.defaultInitiateAuthentication(rw, req, session, redirectURL) -} - -// handleCallback processes the authentication callback from the OIDC provider. -// It validates the callback parameters, exchanges the authorization code for -// tokens, verifies the tokens, and establishes the user's session. -func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { - session, err := t.sessionManager.GetSession(req) - if err != nil { - t.logger.Errorf("Session error: %v", err) - http.Error(rw, "Session error", http.StatusInternalServerError) - return - } - - t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) - - // Check for errors in the callback - if req.URL.Query().Get("error") != "" { - errorDescription := req.URL.Query().Get("error_description") - t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) - http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) - return - } - - // Validate CSRF state - state := req.URL.Query().Get("state") - if state == "" { - t.logger.Error("No state in callback") - http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) - return - } - - csrfToken := session.GetCSRF() - if csrfToken == "" { - t.logger.Error("CSRF token missing in session") - http.Error(rw, "CSRF token missing", http.StatusBadRequest) - return - } - - if state != csrfToken { - t.logger.Error("State parameter does not match CSRF token in session") - http.Error(rw, "Invalid state parameter", http.StatusBadRequest) - return - } - - // Exchange code for tokens - code := req.URL.Query().Get("code") - if code == "" { - t.logger.Error("No code in callback") - http.Error(rw, "No code in callback", http.StatusBadRequest) - return - } - - // Get the code verifier from the session for PKCE flow - codeVerifier := session.GetCodeVerifier() - - tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier) - if err != nil { - t.logger.Errorf("Failed to exchange code for token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Verify tokens and claims - if err := t.verifyToken(tokenResponse.IDToken); err != nil { - t.logger.Errorf("Failed to verify id_token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - claims, err := t.extractClaimsFunc(tokenResponse.IDToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Verify nonce to prevent replay attacks - nonceClaim, ok := claims["nonce"].(string) - if !ok || nonceClaim == "" { - t.logger.Error("Nonce claim missing in id_token") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - sessionNonce := session.GetNonce() - if sessionNonce == "" { - t.logger.Error("Nonce not found in session") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - if nonceClaim != sessionNonce { - t.logger.Error("Nonce claim does not match session nonce") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Validate user's email domain - email, _ := claims["email"].(string) - if email == "" || !t.isAllowedDomain(email) { - t.logger.Errorf("Invalid or disallowed email: %s", email) - http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) - return - } - - // Update session with authentication data - session.SetAuthenticated(true) - session.SetEmail(email) - session.SetAccessToken(tokenResponse.IDToken) - session.SetRefreshToken(tokenResponse.RefreshToken) - - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session: %v", err) - http.Error(rw, "Failed to save session", http.StatusInternalServerError) - return - } - - // Redirect to original path or root - redirectPath := "/" - if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { - redirectPath = incomingPath - } - - http.Redirect(rw, req, redirectPath, http.StatusFound) -} - // extractClaims parses a JWT token and extracts its claims. // It handles base64url decoding and JSON parsing of the token payload. func extractClaims(tokenString string) (map[string]interface{}, error) { diff --git a/helpers_test.go b/helpers_test.go index a55372c..0b0e047 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -7,172 +7,12 @@ import ( "time" ) -func TestTokenBlacklistSizeLimit(t *testing.T) { - tb := NewTokenBlacklist() - - // Add tokens up to maxSize - for i := 0; i < 1000; i++ { - tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour)) - } - - // Verify size is at max - if tb.Count() != 1000 { - t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count()) - } - - // Add one more token, should trigger cleanup/eviction - tb.Add("newtoken", time.Now().Add(time.Hour)) - - // Size should still be at max - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size: %d", tb.Count()) - } -} - -func TestTokenBlacklistExpiredCleanup(t *testing.T) { - tb := NewTokenBlacklist() - - // Add some expired tokens - for i := 0; i < 500; i++ { - tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour)) - } - - // Add some valid tokens - for i := 0; i < 500; i++ { - tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour)) - } - - // Force cleanup - tb.Cleanup() - - // Only valid tokens should remain - if tb.Count() != 500 { - t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count()) - } - - // Verify only valid tokens remain - tb.mutex.RLock() - defer tb.mutex.RUnlock() - for token, expiry := range tb.tokens { - if time.Now().After(expiry) { - t.Errorf("Found expired token after cleanup: %s", token) - } - } -} - -func TestTokenBlacklistOldestEviction(t *testing.T) { - tb := NewTokenBlacklist() - - // Add tokens at capacity with different expiration times - baseTime := time.Now() - oldestToken := "oldest" - - // Add oldest token first - tb.Add(oldestToken, baseTime.Add(time.Hour)) - - // Fill up to capacity with newer tokens - for i := 0; i < 999; i++ { - tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2)) - } - - // Add a new token that should evict the oldest - newToken := "newest" - tb.Add(newToken, baseTime.Add(time.Hour*3)) - - // Verify oldest token was evicted - if tb.IsBlacklisted(oldestToken) { - t.Error("Oldest token should have been evicted") - } - - // Verify newest token is present - if !tb.IsBlacklisted(newToken) { - t.Error("Newest token should be present") - } -} - -func TestTokenBlacklistMemoryUsage(t *testing.T) { - tb := NewTokenBlacklist() - iterations := 10000 - - // Force initial GC - runtime.GC() - - // Record initial memory stats - var m1, m2 runtime.MemStats - runtime.ReadMemStats(&m1) - - // Simulate heavy usage - for i := 0; i < iterations; i++ { - // Add new token - tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour)) - - // Periodically check blacklisted status - if i%100 == 0 { - tb.IsBlacklisted(fmt.Sprintf("token%d", i-50)) - } - - // Periodically cleanup - if i%1000 == 0 { - tb.Cleanup() - } - } - - // Force GC and wait for it to complete - runtime.GC() - time.Sleep(100 * time.Millisecond) - runtime.ReadMemStats(&m2) - - // Check memory growth (using HeapAlloc for more accurate measurement) - memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc) - maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth - - if memoryGrowth > maxAllowedGrowth { - t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc) - t.Errorf("Excessive memory growth: %d bytes", memoryGrowth) - } - - // Verify size stayed within limits - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size: %d", tb.Count()) - } -} - -func TestConcurrentTokenBlacklistOperations(t *testing.T) { - tb := NewTokenBlacklist() - iterations := 1000 - concurrency := 10 - done := make(chan bool) - - // Start multiple goroutines performing operations - for i := 0; i < concurrency; i++ { - go func(id int) { - for j := 0; j < iterations; j++ { - // Add tokens - token := fmt.Sprintf("token%d-%d", id, j) - tb.Add(token, time.Now().Add(time.Hour)) - - // Check blacklist status - tb.IsBlacklisted(token) - - // Periodic cleanup - if j%100 == 0 { - tb.Cleanup() - } - } - done <- true - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < concurrency; i++ { - <-done - } - - // Verify size constraints were maintained - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count()) - } -} +// Removed tests related to the old TokenBlacklist implementation: +// - TestTokenBlacklistSizeLimit +// - TestTokenBlacklistExpiredCleanup +// - TestTokenBlacklistOldestEviction +// - TestTokenBlacklistMemoryUsage +// - TestConcurrentTokenBlacklistOperations func TestTokenCacheMemoryUsage(t *testing.T) { tc := NewTokenCache() diff --git a/jwt.go b/jwt.go index 7a03b8c..43e81c4 100644 --- a/jwt.go +++ b/jwt.go @@ -15,8 +15,10 @@ import ( "time" ) -var replayCacheMu sync.Mutex -var replayCache = make(map[string]time.Time) +var ( + replayCacheMu sync.Mutex + replayCache = make(map[string]time.Time) +) func cleanupReplayCache() { now := time.Now() @@ -164,6 +166,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } + func verifyAudience(tokenAudience interface{}, expectedAudience string) error { switch aud := tokenAudience.(type) { case string: diff --git a/main.go b/main.go index 526d182..8704ff0 100644 --- a/main.go +++ b/main.go @@ -9,11 +9,10 @@ import ( "net" "net/http" "net/url" + "runtime" "strings" "time" - "runtime" - "github.com/google/uuid" "golang.org/x/time/rate" ) @@ -52,7 +51,10 @@ func createDefaultHTTPClient() *http.Client { } } -const ConstSessionTimeout = 86400 // Session timeout in seconds +const ( + ConstSessionTimeout = 86400 // Session timeout in seconds + defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI +) // TokenVerifier interface for token verification type TokenVerifier interface { @@ -64,6 +66,13 @@ type JWTVerifier interface { VerifyJWTSignatureAndClaims(jwt *JWT, token string) error } +// TokenExchanger defines methods for OIDC token operations +type TokenExchanger interface { + ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) + GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) + RevokeTokenWithProvider(token, tokenType string) error +} + // TraefikOidc is the main struct for the OIDC middleware type TraefikOidc struct { next http.Handler @@ -74,7 +83,7 @@ type TraefikOidc struct { revocationURL string jwkCache JWKCacheInterface metadataCache *MetadataCache - tokenBlacklist *TokenBlacklist + tokenBlacklist *Cache // Replaced TokenBlacklist with generic Cache jwksURL string clientID string clientSecret string @@ -94,12 +103,13 @@ type TraefikOidc struct { allowedUserDomains map[string]struct{} allowedRolesAndGroups map[string]struct{} initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) - exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) - extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - initComplete chan struct{} - endSessionURL string - postLogoutRedirectURI string - sessionManager *SessionManager + // exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) // Replaced by interface + extractClaimsFunc func(tokenString string) (map[string]interface{}, error) + initComplete chan struct{} + endSessionURL string + postLogoutRedirectURI string + sessionManager *SessionManager + tokenExchanger TokenExchanger // Added field for mocking } // ProviderMetadata holds OIDC provider metadata @@ -155,6 +165,29 @@ func (t *TraefikOidc) VerifyToken(token string) error { // Cache the verified token t.cacheVerifiedToken(token, jwt.Claims) + // Add JTI to blacklist AFTER successful verification to prevent replay + if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" { + // Calculate expiry based on 'exp' claim if available, otherwise use default + expiry := time.Now().Add(defaultBlacklistDuration) + if expClaim, expOk := jwt.Claims["exp"].(float64); expOk { + expTime := time.Unix(int64(expClaim), 0) + tokenDuration := time.Until(expTime) + // Use token expiry if longer than default, capped at a reasonable max (e.g., 24h) + if tokenDuration > defaultBlacklistDuration && tokenDuration < (24*time.Hour) { + expiry = expTime + } else if tokenDuration <= 0 { + // If token already expired but somehow passed verification, use default + expiry = time.Now().Add(defaultBlacklistDuration) + } else { + // Use default if token expiry is shorter or excessively long + expiry = time.Now().Add(defaultBlacklistDuration) + } + } + // Use Set with a duration. Value 'true' is arbitrary, we only care about existence. + t.tokenBlacklist.Set(jti, true, time.Until(expiry)) + t.logger.Debugf("Added JTI %s to blacklist cache", jti) + } + return nil } @@ -165,11 +198,22 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error { return fmt.Errorf("rate limit exceeded") } - // Check if token is blacklisted - if t.tokenBlacklist.IsBlacklisted(token) { - return fmt.Errorf("token is blacklisted") + // Check if the raw token string itself is blacklisted (e.g., via explicit revocation) + if _, exists := t.tokenBlacklist.Get(token); exists { + return fmt.Errorf("token is blacklisted (raw string) in cache") } + // Also check if the JTI claim is blacklisted (replay detection) + claims, err := extractClaims(token) // Use existing helper + if err == nil { // Only check JTI if claims could be extracted + if jti, ok := claims["jti"].(string); ok && jti != "" { + if _, exists := t.tokenBlacklist.Get(jti); exists { + // Use a specific error message for replay + return fmt.Errorf("token replay detected (jti: %s) in cache", jti) + } + } + } // If claims extraction fails, proceed; full validation will catch token issues later. + return nil } @@ -296,7 +340,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } return config.PostLogoutRedirectURI }(), - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist jwkCache: &JWKCache{}, metadataCache: NewMetadataCache(), clientID: config.ClientID, @@ -316,7 +360,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims - t.exchangeCodeForTokenFunc = t.exchangeCodeForToken + // t.exchangeCodeForTokenFunc = t.exchangeCodeForToken // Removed, using interface now t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { t.defaultInitiateAuthentication(rw, req, session, redirectURL) } @@ -329,6 +373,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.tokenVerifier = t t.jwtVerifier = t t.startTokenCleanup() + t.tokenExchanger = t // Initialize the interface field to self go t.initializeMetadata(config.ProviderURL) return t, nil @@ -532,15 +577,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if !authenticated { + // Original logic: Always initiate authentication if not authenticated + t.logger.Debug("User not authenticated, initiating OIDC flow") t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return + return // Stop processing } if needsRefresh { refreshed := t.refreshToken(rw, req, session) if !refreshed { + // Original logic: Always handle failed refresh as an expired token + t.logger.Debug("Token refresh failed, handling as expired token") t.handleExpiredToken(rw, req, session, redirectURL) - return + return // Stop processing } } @@ -621,6 +670,151 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.next.ServeHTTP(rw, req) } +// handleExpiredToken manages token expiration by clearing the session +// and initiating a new authentication flow. +func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { + // Clear authentication data but preserve CSRF state + session.SetAuthenticated(false) + session.SetAccessToken("") + session.SetRefreshToken("") + session.SetEmail("") + + // Save the cleared session state + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save cleared session: %v", err) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + return + } + + t.defaultInitiateAuthentication(rw, req, session, redirectURL) +} + +// handleCallback processes the authentication callback from the OIDC provider. +// It validates the callback parameters, exchanges the authorization code for +// tokens, verifies the tokens, and establishes the user's session. +func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { + session, err := t.sessionManager.GetSession(req) + if err != nil { + t.logger.Errorf("Session error: %v", err) + http.Error(rw, "Session error", http.StatusInternalServerError) + return + } + + t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) + + // Check for errors in the callback + if req.URL.Query().Get("error") != "" { + errorDescription := req.URL.Query().Get("error_description") + t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) + http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) + return + } + + // Validate CSRF state + state := req.URL.Query().Get("state") + if state == "" { + t.logger.Error("No state in callback") + http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) + return + } + + csrfToken := session.GetCSRF() + if csrfToken == "" { + t.logger.Error("CSRF token missing in session") + http.Error(rw, "CSRF token missing", http.StatusBadRequest) + return + } + + if state != csrfToken { + t.logger.Error("State parameter does not match CSRF token in session") + http.Error(rw, "Invalid state parameter", http.StatusBadRequest) + return + } + + // Exchange code for tokens + code := req.URL.Query().Get("code") + if code == "" { + t.logger.Error("No code in callback") + http.Error(rw, "No code in callback", http.StatusBadRequest) + return + } + + // Get the code verifier from the session for PKCE flow + codeVerifier := session.GetCodeVerifier() + + tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier) + if err != nil { + t.logger.Errorf("Failed to exchange code for token: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Verify tokens and claims + // Use the exported VerifyToken method now that handleCallback is in main.go + if err := t.VerifyToken(tokenResponse.IDToken); err != nil { + t.logger.Errorf("Failed to verify id_token: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + claims, err := t.extractClaimsFunc(tokenResponse.IDToken) + if err != nil { + t.logger.Errorf("Failed to extract claims: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Verify nonce to prevent replay attacks + nonceClaim, ok := claims["nonce"].(string) + if !ok || nonceClaim == "" { + t.logger.Error("Nonce claim missing in id_token") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + sessionNonce := session.GetNonce() + if sessionNonce == "" { + t.logger.Error("Nonce not found in session") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + if nonceClaim != sessionNonce { + t.logger.Error("Nonce claim does not match session nonce") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Validate user's email domain + // Use the unexported isAllowedDomain method now that handleCallback is in main.go + email, _ := claims["email"].(string) + if email == "" || !t.isAllowedDomain(email) { + t.logger.Errorf("Invalid or disallowed email: %s", email) + http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) + return + } + + // Update session with authentication data + session.SetAuthenticated(true) + session.SetEmail(email) + session.SetAccessToken(tokenResponse.IDToken) + session.SetRefreshToken(tokenResponse.RefreshToken) + + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session: %v", err) + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + + // Redirect to original path or root + redirectPath := "/" + if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { + redirectPath = incomingPath + } + + http.Redirect(rw, req, redirectPath, http.StatusFound) +} + // determineExcludedURL checks if the current request URL is in the excluded list func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { for excludedURL := range t.excludedURLs { @@ -675,17 +869,26 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return false, false, true // Session is invalid, consider it expired } - // Verify the token - if err := t.verifyToken(accessToken); err != nil { - t.logger.Errorf("Token verification failed: %v", err) - return false, false, true // Token is invalid, consider it expired + // Verify the token structure and signature first + jwt, err := parseJWT(accessToken) + if err != nil { + t.logger.Errorf("Failed to parse JWT during auth check: %v", err) + return false, false, true // Invalid format, treat as expired/invalid + } + if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil { + // Check if the error is specifically about expiration + if strings.Contains(err.Error(), "token has expired") { + t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh") + // Token is expired but otherwise valid, signal for refresh + return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible) + } + // Other verification error (signature, issuer, audience etc.) + t.logger.Errorf("Token verification failed (non-expiration): %v", err) + return false, false, true // Token is invalid for other reasons } - claims, err := extractClaims(accessToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - return false, false, true - } + // Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early + claims := jwt.Claims expClaim, ok := claims["exp"].(float64) if !ok { @@ -696,17 +899,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo now := time.Now().Unix() expTime := int64(expClaim) - if now > expTime { - t.logger.Debug("Token has expired") - return false, false, true - } - - gracePeriod := time.Minute * 5 - if now+int64(gracePeriod.Seconds()) > expTime { - t.logger.Debug("Token will expire soon") - return true, true, false // Token will expire soon, needs refresh + // Expiration check is now handled within VerifyJWTSignatureAndClaims logic above + // We only get here if the token is valid and not expired + + // Check if token is nearing expiration (needs refresh proactively) + // Define a grace period, e.g., 5 minutes before actual expiry + refreshGracePeriod := int64(5 * 60) + if expTime-now < refreshGracePeriod { + t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod) + return true, true, false // Needs proactive refresh } + // Token is valid, not expired, and not nearing expiration return true, false, false } @@ -827,7 +1031,7 @@ func (t *TraefikOidc) startTokenCleanup() { for range ticker.C { t.logger.Debug("Starting token cleanup cycle") t.tokenCache.Cleanup() - t.tokenBlacklist.Cleanup() + // t.tokenBlacklist.Cleanup() // Removed: Generic Cache handles its own cleanup t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go // Removed runtime.GC() call } @@ -841,7 +1045,8 @@ func (t *TraefikOidc) RevokeToken(token string) { // Add to blacklist with default expiration expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration - t.tokenBlacklist.Add(token, expiry) + // Use Set with a duration. Value 'true' is arbitrary, we only care about existence. + t.tokenBlacklist.Set(token, true, time.Until(expiry)) } // RevokeTokenWithProvider revokes the token with the provider @@ -890,7 +1095,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return false } - newToken, err := t.getNewTokenWithRefreshToken(refreshToken) + newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) if err != nil { t.logger.Errorf("Failed to refresh token: %v", err) return false @@ -986,3 +1191,19 @@ func buildFullURL(scheme, host, path string) string { return fmt.Sprintf("%s://%s%s", scheme, host, path) } + +// --- TokenExchanger Interface Implementation --- + +// ExchangeCodeForToken implements the TokenExchanger interface. +// It calls the existing exchangeTokens helper function. +func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { + // Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc + return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier) +} + +// GetNewTokenWithRefreshToken implements the TokenExchanger interface. +// It calls the existing getNewTokenWithRefreshToken helper function. +func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { + // Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc + return t.getNewTokenWithRefreshToken(refreshToken) +} diff --git a/main_test.go b/main_test.go index 4535798..fff7910 100644 --- a/main_test.go +++ b/main_test.go @@ -101,29 +101,47 @@ func (ts *TestSuite) Setup() { jwksURL: "https://test-jwks-url.com", revocationURL: "https://revocation-endpoint.com", limiter: rate.NewLimiter(rate.Every(time.Second), 10), - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), logger: logger, allowedUserDomains: map[string]struct{}{"example.com": {}}, excludedURLs: map[string]struct{}{"/favicon": {}}, httpClient: &http.Client{}, - extractClaimsFunc: extractClaims, - initComplete: make(chan struct{}), - sessionManager: ts.sessionManager, + // Explicitly set paths as New() is bypassed + redirURLPath: "/callback", // Assume default callback path for tests + logoutURLPath: "/callback/logout", // Assume default logout path for tests + tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests + extractClaimsFunc: extractClaims, + initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, } close(ts.tOidc.initComplete) - ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc + // ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc // Removed ts.tOidc.tokenVerifier = ts.tOidc ts.tOidc.jwtVerifier = ts.tOidc + // Set default mock exchanger + ts.tOidc.tokenExchanger = &MockTokenExchanger{ + ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + // Default mock behavior for code exchange + return &TokenResponse{ + IDToken: ts.token, // Use the valid token from setup + AccessToken: ts.token, + RefreshToken: "default-refresh-token", + ExpiresIn: 3600, + }, nil + }, + RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) { + // Default mock behavior for refresh (can be overridden in tests) + return nil, fmt.Errorf("default mock: refresh not expected") + }, + RevokeTokenFunc: func(token, tokenType string) error { + // Default mock behavior for revoke + return nil + }, + } } -// Helper functions used by TraefikOidc -func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { - return &TokenResponse{ - IDToken: ts.token, - RefreshToken: "test-refresh-token", - }, nil -} +// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface. // MockJWKCache implements JWKCacheInterface type MockJWKCache struct { @@ -141,6 +159,34 @@ func (m *MockJWKCache) Cleanup() { m.Err = nil } +// MockTokenExchanger implements TokenExchanger for testing +type MockTokenExchanger struct { + ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) + RefreshTokenFunc func(refreshToken string) (*TokenResponse, error) + RevokeTokenFunc func(token, tokenType string) error +} + +func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + if m.ExchangeCodeFunc != nil { + return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier) + } + return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock") +} + +func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { + if m.RefreshTokenFunc != nil { + return m.RefreshTokenFunc(refreshToken) + } + return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock") +} + +func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error { + if m.RevokeTokenFunc != nil { + return m.RevokeTokenFunc(token, tokenType) + } + return fmt.Errorf("RevokeTokenFunc not implemented in mock") +} + // Helper function to create a JWT token func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) { header := map[string]interface{}{ @@ -228,13 +274,14 @@ func TestVerifyToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache for each test - ts.tOidc.tokenBlacklist = NewTokenBlacklist() + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { - ts.tOidc.tokenBlacklist.Add(tc.token, time.Now().Add(1*time.Hour)) + // Use Set with a duration. Value 'true' is arbitrary. + ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour) } if tc.rateLimit { @@ -282,13 +329,53 @@ func TestServeHTTP(t *testing.T) { ts.tOidc.next = nextHandler ts.tOidc.name = "test" + // Helper to create an expired token + createExpiredToken := func() string { + exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago + iat := time.Now().Add(-2 * time.Hour).Unix() + nbf := time.Now().Add(-2 * time.Hour).Unix() + expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce-expired", // Different nonce for clarity + "jti": generateRandomString(16), + }) + return expiredToken + } + + // Helper to create a new valid token (simulating refresh) + createNewValidToken := func() string { + exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour + iat := time.Now().Unix() + nbf := time.Now().Unix() + newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + // "nonce": "test-nonce-new", // Nonce is typically not included/validated in refreshed tokens + "jti": generateRandomString(16), + }) + return newToken + } + tests := []struct { - name string - requestPath string - sessionValues map[interface{}]interface{} - expectedStatus int - expectedBody string - setupSession func(*SessionData) + name string + requestPath string + sessionValues map[interface{}]interface{} + expectedStatus int + expectedBody string + setupSession func(*SessionData) + mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) + assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks }{ { name: "Excluded URL", @@ -299,28 +386,77 @@ func TestServeHTTP(t *testing.T) { { name: "Unauthenticated request to protected URL", requestPath: "/protected", - expectedStatus: http.StatusFound, + expectedStatus: http.StatusFound, // Expect redirect to OIDC }, { - name: "Authenticated request to protected URL", + name: "Authenticated request to protected URL (Valid Token)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") - session.SetAccessToken(ts.token) + session.SetAccessToken(ts.token) // Use the valid token generated in Setup + session.SetRefreshToken("valid-refresh-token") }, expectedStatus: http.StatusOK, expectedBody: "OK", }, + { + name: "Authenticated request with expired token and successful refresh", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) // Still marked authenticated initially + session.SetEmail("user@example.com") + session.SetAccessToken(createExpiredToken()) // Set expired token + session.SetRefreshToken("valid-refresh-token") // Set valid refresh token + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + if refreshToken != "valid-refresh-token" { + return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken) + } + // Simulate successful refresh + newToken := createNewValidToken() + return &TokenResponse{ + IDToken: newToken, // Return new valid token + AccessToken: newToken, // Often the same as ID token in tests + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + }, nil + } + }, + expectedStatus: http.StatusOK, // Expect success after refresh + expectedBody: "OK", + assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) { + // Create a new request to read the cookies set by the response recorder + reqForCookieRead := httptest.NewRequest("GET", "/protected", nil) + for _, cookie := range rr.Result().Cookies() { + reqForCookieRead.AddCookie(cookie) + } + // Get session based on response cookies + session, err := sessionManager.GetSession(reqForCookieRead) + if err != nil { + t.Fatalf("Failed to get session after request: %v", err) + } + // Assert new tokens are in the session + // Direct comparison with createNewValidToken() is flawed as it generates a new token each time. + // Instead, check if the token was updated (not empty) and verify the refresh token. + if session.GetAccessToken() == "" { + t.Errorf("Expected access token to be updated in session, but it was empty") + } + if session.GetRefreshToken() != "new-refresh-token" { + t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken()) + } + }, + }, { name: "Logout URL", - requestPath: "/logout", + requestPath: "/logout", // Assuming logout path is configured or defaulted correctly setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(ts.token) }, - expectedStatus: http.StatusOK, + expectedStatus: http.StatusFound, // Expect redirect after logout expectedBody: "", }, } @@ -328,40 +464,79 @@ func TestServeHTTP(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest("GET", tc.requestPath, nil) - req.Header.Set("X-Forwarded-Proto", "http") - req.Header.Set("X-Forwarded-Host", "localhost") + // Set common headers needed by the logic (determineScheme, determineHost) + req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that + req.Header.Set("X-Forwarded-Host", "testhost.com") + req.Host = "testhost.com" // Also set Host header + rr := httptest.NewRecorder() // Setup session if needed session, err := ts.tOidc.sessionManager.GetSession(req) if err != nil { - t.Fatalf("Failed to get session: %v", err) + t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err) } if tc.setupSession != nil { tc.setupSession(session) - if err := session.Save(req, rr); err != nil { - t.Fatalf("Failed to save session: %v", err) + // Save session to recorder to get cookies + saveRecorder := httptest.NewRecorder() + if err := session.Save(req, saveRecorder); err != nil { + t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err) } - - // Copy cookies to the new request - for _, cookie := range rr.Result().Cookies() { + // Copy cookies from save recorder to the actual request + for _, cookie := range saveRecorder.Result().Cookies() { req.AddCookie(cookie) } - rr = httptest.NewRecorder() + } + + // Mocking setup for TokenExchanger + originalExchanger := ts.tOidc.tokenExchanger // Store original + mockExchanger, isMock := originalExchanger.(*MockTokenExchanger) + if !isMock { + // This case should ideally not happen if Setup correctly assigns the mock, + // but handle it defensively. + t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.") + mockExchanger = &MockTokenExchanger{ + ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken, + RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken, + RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider, + } + ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock + } + + // Override specific mock methods if needed for the test case + originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func + if tc.mockRefreshTokenFunc != nil { + // Assign the test case specific mock function + mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken) } // Call ServeHTTP ts.tOidc.ServeHTTP(rr, req) - // Check response - if rr.Code != tc.expectedStatus { - t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) + // Restore original exchanger and mock function state + ts.tOidc.tokenExchanger = originalExchanger + if tc.mockRefreshTokenFunc != nil && mockExchanger != nil { + // Restore the previous mock function if we overrode it + mockExchanger.RefreshTokenFunc = originalMockRefreshFunc } + + // Check response status + if rr.Code != tc.expectedStatus { + t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String()) + } + + // Check response body if expected if tc.expectedBody != "" { if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { - t.Errorf("Expected body %q, got %q", tc.expectedBody, body) + t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body) } } + + // Perform post-request session assertions if defined + if tc.assertSessionAfterRequest != nil { + tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager) + } }) } } @@ -552,17 +727,39 @@ func TestHandleCallback(t *testing.T) { name: "Disallowed Email", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { + // Generate a unique token for this test case to avoid replay issues + // Use claims relevant to this test (disallowed email) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject-disallowed", + "email": "user@disallowed.com", // The disallowed email for this test + "nonce": "test-nonce", // Match the nonce set in sessionSetupFunc + "jti": generateRandomString(16), // Unique JTI + }) + if err != nil { + return nil, fmt.Errorf("failed to create disallowed token for test: %w", err) + } return &TokenResponse{ - IDToken: ts.token, - RefreshToken: "test-refresh-token", - }, nil - }, - extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { - return map[string]interface{}{ - "email": "user@disallowed.com", - "nonce": "test-nonce", + IDToken: disallowedToken, + RefreshToken: "test-refresh-token-disallowed", }, nil }, + // Remove mock extractClaimsFunc - let the real one parse the disallowedToken + // The test should still fail correctly on the email check later. + // extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // return map[string]interface{}{ + // "email": "user@disallowed.com", + // "nonce": "test-nonce", + // }, nil + // }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") @@ -635,20 +832,61 @@ func TestHandleCallback(t *testing.T) { } for _, tc := range tests { + tc := tc // Capture range variable t.Run(tc.name, func(t *testing.T) { + // Clear the global replay cache before each test run + replayCacheMu.Lock() + replayCache = make(map[string]time.Time) // Reset the global cache + replayCacheMu.Unlock() + + // Explicitly clear the shared blacklist at the start of each sub-test + // to ensure no state leaks, even though we expect the local one to be used. + // Note: This line might be redundant now that the verifier is local, but keep for safety. + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist + logger := NewLogger("info") sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Create a new instance for each test to avoid state carryover - tOidc := &TraefikOidc{ - allowedUserDomains: map[string]struct{}{"example.com": {}}, - logger: logger, - exchangeCodeForTokenFunc: tc.exchangeCodeForToken, - extractClaimsFunc: tc.extractClaimsFunc, - tokenVerifier: ts.tOidc.tokenVerifier, - jwtVerifier: ts.tOidc.jwtVerifier, - sessionManager: sessionManager, + instanceExtractClaimsFunc := tc.extractClaimsFunc + if instanceExtractClaimsFunc == nil { + instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case } + tOidc := &TraefikOidc{ + allowedUserDomains: map[string]struct{}{"example.com": {}}, + logger: logger, + // exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field + extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function + tokenVerifier: nil, // Will be set to self below + jwtVerifier: nil, // Temporarily nil, will be set below + sessionManager: sessionManager, + tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run + ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + // Wrap the test case function to match the required signature + if tc.exchangeCodeForToken != nil { + // Only call if the test case provided a function + return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier) + } + // Provide a default behavior or error if no mock was provided for this test case + return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case") + }, + // Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback + }, + tokenCache: NewTokenCache(), // Initialize token cache + limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter + tokenBlacklist: NewCache(), // Initialize token blacklist cache + + // Add potentially missing fields based on New() comparison + clientID: ts.tOidc.clientID, + issuerURL: ts.tOidc.issuerURL, + jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite + httpClient: ts.tOidc.httpClient, + initComplete: make(chan struct{}), // Initialize the channel + // Setting other fields like paths, enablePKCE etc. if needed + } + tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW + tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW + close(tOidc.initComplete) // Mark this test instance as initialized // Create request and response recorder req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil) @@ -839,13 +1077,14 @@ func TestOIDCHandler(t *testing.T) { tc := tc // Capture range variable t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache - ts.tOidc.tokenBlacklist = NewTokenBlacklist() + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { - ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour)) + // Use Set with a duration. Value 'true' is arbitrary. + ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour) } if tc.rateLimit { @@ -948,7 +1187,7 @@ func TestHandleLogout(t *testing.T) { endSessionURL: tc.endSessionURL, scheme: "http", logger: logger, - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist httpClient: &http.Client{}, clientID: "test-client-id", clientSecret: "test-client-secret", @@ -1018,13 +1257,13 @@ func TestHandleLogout(t *testing.T) { // Check token blacklist if token := session.GetAccessToken(); token != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(token) { - t.Error("Access token was not blacklisted") + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { + t.Error("Access token was not blacklisted in cache") } } if token := session.GetRefreshToken(); token != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(token) { - t.Error("Refresh token was not blacklisted") + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { + t.Error("Refresh token was not blacklisted in cache") } } }) @@ -1121,7 +1360,7 @@ func TestRevokeToken(t *testing.T) { t.Run("Token revocation", func(t *testing.T) { // Create a new instance for this specific test tOidc := &TraefikOidc{ - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), } @@ -1136,8 +1375,8 @@ func TestRevokeToken(t *testing.T) { t.Error("Token was not removed from cache") } - // Verify token was added to blacklist - if !tOidc.tokenBlacklist.IsBlacklisted(token) { + // Verify token was added to blacklist cache + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { t.Error("Token was not added to blacklist") } }) @@ -1404,7 +1643,6 @@ func TestMultipleMiddlewareInstances(t *testing.T) { middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test") - if err != nil { t.Fatalf("Failed to create middleware for route %s: %v", route, err) } @@ -2043,7 +2281,6 @@ func TestExchangeCodeForToken(t *testing.T) { // Test exchangeCodeForToken response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier) - if err != nil { t.Errorf("Unexpected error: %v", err) } diff --git a/metadata_cache.go b/metadata_cache.go index 64f4b6c..0e96d59 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -34,6 +34,7 @@ func (c *MetadataCache) Cleanup() { c.metadata = nil } } + func (c *MetadataCache) isCacheValid() bool { return c.metadata != nil && time.Now().Before(c.expiresAt) } @@ -67,15 +68,9 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, } c.metadata = metadata - // Calculate expiration time based on usage patterns - usageCount := 0 // This should be replaced with actual usage tracking logic - if usageCount < 10 { - c.expiresAt = time.Now().Add(30 * time.Minute) - } else if usageCount < 50 { - c.expiresAt = time.Now().Add(1 * time.Hour) - } else { - c.expiresAt = time.Now().Add(2 * time.Hour) - } + // Set a fixed cache lifetime (e.g., 1 hour) + // TODO: Consider making this configurable or respecting HTTP cache headers + c.expiresAt = time.Now().Add(1 * time.Hour) // End of GetMetadata return metadata, nil diff --git a/session_test.go b/session_test.go index 6ea172a..c145ddc 100644 --- a/session_test.go +++ b/session_test.go @@ -1,7 +1,9 @@ package traefikoidc import ( - "math/rand" + "crypto/rand" + "fmt" + "math/big" "net/http/httptest" "strings" "testing" @@ -12,7 +14,12 @@ func generateRandomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { - b[i] = charset[rand.Intn(len(charset))] + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + // Handle error appropriately in a real application, maybe panic in test helper + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + b[i] = charset[num.Int64()] } return string(b) }