diff --git a/blacklist.go b/blacklist.go deleted file mode 100644 index 27a75c8..0000000 --- a/blacklist.go +++ /dev/null @@ -1,110 +0,0 @@ -package traefikoidc - -import ( - "sync" - "time" -) - -// TokenBlacklist manages a thread-safe list of revoked tokens with expiration. -type TokenBlacklist struct { - tokens map[string]time.Time - mutex sync.RWMutex -} - -// NewTokenBlacklist creates a new token blacklist instance. -func NewTokenBlacklist() *TokenBlacklist { - return &TokenBlacklist{ - tokens: make(map[string]time.Time), - } -} - -// Add adds a token to the blacklist with an expiration time. -func (b *TokenBlacklist) Add(token string, expiry time.Time) { - b.mutex.Lock() - defer b.mutex.Unlock() - - // Clean up expired tokens if we're at capacity - if len(b.tokens) >= 1000 { - now := time.Now() - futureThreshold := now.Add(time.Minute) - for t, exp := range b.tokens { - if now.After(exp) || futureThreshold.After(exp) { - delete(b.tokens, t) - } - } - - // If still at capacity, remove oldest token - if len(b.tokens) >= 1000 { - var oldestToken string - var oldestTime time.Time - first := true - for t, exp := range b.tokens { - if first || exp.Before(oldestTime) { - oldestToken = t - oldestTime = exp - first = false - } - } - if oldestToken != "" { - delete(b.tokens, oldestToken) - } - } - } - - b.tokens[token] = expiry -} - -// IsBlacklisted checks if a token is in the blacklist and not expired. -func (b *TokenBlacklist) IsBlacklisted(token string) bool { - b.mutex.RLock() - defer b.mutex.RUnlock() - - expiry, exists := b.tokens[token] - if !exists { - return false - } - - // If token is expired, remove it and return false - if time.Now().After(expiry) { - // Switch to write lock to remove expired token - b.mutex.RUnlock() - b.mutex.Lock() - delete(b.tokens, token) - b.mutex.Unlock() - b.mutex.RLock() - return false - } - - return true -} - -// Cleanup removes expired tokens from the blacklist. -// Also removes tokens that will expire within the next minute to prevent edge cases. -func (b *TokenBlacklist) Cleanup() { - b.mutex.Lock() - defer b.mutex.Unlock() - - now := time.Now() - futureThreshold := now.Add(time.Minute) - - for token, expiry := range b.tokens { - // Remove tokens that are expired or will expire soon - if now.After(expiry) || futureThreshold.After(expiry) { - delete(b.tokens, token) - } - } -} - -// Remove removes a token from the blacklist regardless of its expiration. -func (b *TokenBlacklist) Remove(token string) { - b.mutex.Lock() - defer b.mutex.Unlock() - delete(b.tokens, token) -} - -// Count returns the current number of tokens in the blacklist. -func (b *TokenBlacklist) Count() int { - b.mutex.RLock() - defer b.mutex.RUnlock() - return len(b.tokens) -} diff --git a/blacklist_test.go b/blacklist_test.go deleted file mode 100644 index 571348c..0000000 --- a/blacklist_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package traefikoidc - -import ( - "testing" - "time" -) - -func TestTokenBlacklist_Add(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - - if !blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be blacklisted, but it was not") - } -} - -func TestTokenBlacklist_IsBlacklisted(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - - if !blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be blacklisted, but it was not") - } - - if blacklist.IsBlacklisted("nonExistentToken") { - t.Errorf("Expected non-existent token to not be blacklisted, but it was") - } -} - -func TestTokenBlacklist_Cleanup(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(-time.Hour) // Expired token - - blacklist.Add(token, expiry) - blacklist.Cleanup() - - if blacklist.IsBlacklisted(token) { - t.Errorf("Expected expired token to be removed after cleanup, but it was not") - } -} - -func TestTokenBlacklist_Remove(t *testing.T) { - blacklist := NewTokenBlacklist() - token := "testToken" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token, expiry) - blacklist.Remove(token) - - if blacklist.IsBlacklisted(token) { - t.Errorf("Expected token to be removed, but it was not") - } -} - -func TestTokenBlacklist_Count(t *testing.T) { - blacklist := NewTokenBlacklist() - token1 := "token1" - token2 := "token2" - expiry := time.Now().Add(time.Hour) - - blacklist.Add(token1, expiry) - blacklist.Add(token2, expiry) - - if blacklist.Count() != 2 { - t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count()) - } -} diff --git a/helpers.go b/helpers.go index 6d27528..9f9f64e 100644 --- a/helpers.go +++ b/helpers.go @@ -81,7 +81,7 @@ type TokenResponse struct { // - codeOrToken: Either the authorization code or refresh token // - redirectURL: The callback URL for authorization code grant // - codeVerifier: Optional PKCE code verifier for authorization code grant -func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) { +func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { data := url.Values{ "grant_type": {grantType}, "client_id": {t.clientID}, @@ -153,149 +153,6 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe return tokenResponse, nil } -// handleExpiredToken manages token expiration by clearing the session -// and initiating a new authentication flow. -func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { - // Clear authentication data but preserve CSRF state - session.SetAuthenticated(false) - session.SetAccessToken("") - session.SetRefreshToken("") - session.SetEmail("") - - // Save the cleared session state - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save cleared session: %v", err) - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return - } - - t.defaultInitiateAuthentication(rw, req, session, redirectURL) -} - -// handleCallback processes the authentication callback from the OIDC provider. -// It validates the callback parameters, exchanges the authorization code for -// tokens, verifies the tokens, and establishes the user's session. -func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { - session, err := t.sessionManager.GetSession(req) - if err != nil { - t.logger.Errorf("Session error: %v", err) - http.Error(rw, "Session error", http.StatusInternalServerError) - return - } - - t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) - - // Check for errors in the callback - if req.URL.Query().Get("error") != "" { - errorDescription := req.URL.Query().Get("error_description") - t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) - http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) - return - } - - // Validate CSRF state - state := req.URL.Query().Get("state") - if state == "" { - t.logger.Error("No state in callback") - http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) - return - } - - csrfToken := session.GetCSRF() - if csrfToken == "" { - t.logger.Error("CSRF token missing in session") - http.Error(rw, "CSRF token missing", http.StatusBadRequest) - return - } - - if state != csrfToken { - t.logger.Error("State parameter does not match CSRF token in session") - http.Error(rw, "Invalid state parameter", http.StatusBadRequest) - return - } - - // Exchange code for tokens - code := req.URL.Query().Get("code") - if code == "" { - t.logger.Error("No code in callback") - http.Error(rw, "No code in callback", http.StatusBadRequest) - return - } - - // Get the code verifier from the session for PKCE flow - codeVerifier := session.GetCodeVerifier() - - tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier) - if err != nil { - t.logger.Errorf("Failed to exchange code for token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Verify tokens and claims - if err := t.verifyToken(tokenResponse.IDToken); err != nil { - t.logger.Errorf("Failed to verify id_token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - claims, err := t.extractClaimsFunc(tokenResponse.IDToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Verify nonce to prevent replay attacks - nonceClaim, ok := claims["nonce"].(string) - if !ok || nonceClaim == "" { - t.logger.Error("Nonce claim missing in id_token") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - sessionNonce := session.GetNonce() - if sessionNonce == "" { - t.logger.Error("Nonce not found in session") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - if nonceClaim != sessionNonce { - t.logger.Error("Nonce claim does not match session nonce") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) - return - } - - // Validate user's email domain - email, _ := claims["email"].(string) - if email == "" || !t.isAllowedDomain(email) { - t.logger.Errorf("Invalid or disallowed email: %s", email) - http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) - return - } - - // Update session with authentication data - session.SetAuthenticated(true) - session.SetEmail(email) - session.SetAccessToken(tokenResponse.IDToken) - session.SetRefreshToken(tokenResponse.RefreshToken) - - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session: %v", err) - http.Error(rw, "Failed to save session", http.StatusInternalServerError) - return - } - - // Redirect to original path or root - redirectPath := "/" - if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { - redirectPath = incomingPath - } - - http.Redirect(rw, req, redirectPath, http.StatusFound) -} - // extractClaims parses a JWT token and extracts its claims. // It handles base64url decoding and JSON parsing of the token payload. func extractClaims(tokenString string) (map[string]interface{}, error) { diff --git a/helpers_test.go b/helpers_test.go index a55372c..0b0e047 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -7,172 +7,12 @@ import ( "time" ) -func TestTokenBlacklistSizeLimit(t *testing.T) { - tb := NewTokenBlacklist() - - // Add tokens up to maxSize - for i := 0; i < 1000; i++ { - tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour)) - } - - // Verify size is at max - if tb.Count() != 1000 { - t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count()) - } - - // Add one more token, should trigger cleanup/eviction - tb.Add("newtoken", time.Now().Add(time.Hour)) - - // Size should still be at max - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size: %d", tb.Count()) - } -} - -func TestTokenBlacklistExpiredCleanup(t *testing.T) { - tb := NewTokenBlacklist() - - // Add some expired tokens - for i := 0; i < 500; i++ { - tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour)) - } - - // Add some valid tokens - for i := 0; i < 500; i++ { - tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour)) - } - - // Force cleanup - tb.Cleanup() - - // Only valid tokens should remain - if tb.Count() != 500 { - t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count()) - } - - // Verify only valid tokens remain - tb.mutex.RLock() - defer tb.mutex.RUnlock() - for token, expiry := range tb.tokens { - if time.Now().After(expiry) { - t.Errorf("Found expired token after cleanup: %s", token) - } - } -} - -func TestTokenBlacklistOldestEviction(t *testing.T) { - tb := NewTokenBlacklist() - - // Add tokens at capacity with different expiration times - baseTime := time.Now() - oldestToken := "oldest" - - // Add oldest token first - tb.Add(oldestToken, baseTime.Add(time.Hour)) - - // Fill up to capacity with newer tokens - for i := 0; i < 999; i++ { - tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2)) - } - - // Add a new token that should evict the oldest - newToken := "newest" - tb.Add(newToken, baseTime.Add(time.Hour*3)) - - // Verify oldest token was evicted - if tb.IsBlacklisted(oldestToken) { - t.Error("Oldest token should have been evicted") - } - - // Verify newest token is present - if !tb.IsBlacklisted(newToken) { - t.Error("Newest token should be present") - } -} - -func TestTokenBlacklistMemoryUsage(t *testing.T) { - tb := NewTokenBlacklist() - iterations := 10000 - - // Force initial GC - runtime.GC() - - // Record initial memory stats - var m1, m2 runtime.MemStats - runtime.ReadMemStats(&m1) - - // Simulate heavy usage - for i := 0; i < iterations; i++ { - // Add new token - tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour)) - - // Periodically check blacklisted status - if i%100 == 0 { - tb.IsBlacklisted(fmt.Sprintf("token%d", i-50)) - } - - // Periodically cleanup - if i%1000 == 0 { - tb.Cleanup() - } - } - - // Force GC and wait for it to complete - runtime.GC() - time.Sleep(100 * time.Millisecond) - runtime.ReadMemStats(&m2) - - // Check memory growth (using HeapAlloc for more accurate measurement) - memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc) - maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth - - if memoryGrowth > maxAllowedGrowth { - t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc) - t.Errorf("Excessive memory growth: %d bytes", memoryGrowth) - } - - // Verify size stayed within limits - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size: %d", tb.Count()) - } -} - -func TestConcurrentTokenBlacklistOperations(t *testing.T) { - tb := NewTokenBlacklist() - iterations := 1000 - concurrency := 10 - done := make(chan bool) - - // Start multiple goroutines performing operations - for i := 0; i < concurrency; i++ { - go func(id int) { - for j := 0; j < iterations; j++ { - // Add tokens - token := fmt.Sprintf("token%d-%d", id, j) - tb.Add(token, time.Now().Add(time.Hour)) - - // Check blacklist status - tb.IsBlacklisted(token) - - // Periodic cleanup - if j%100 == 0 { - tb.Cleanup() - } - } - done <- true - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < concurrency; i++ { - <-done - } - - // Verify size constraints were maintained - if tb.Count() > 1000 { - t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count()) - } -} +// Removed tests related to the old TokenBlacklist implementation: +// - TestTokenBlacklistSizeLimit +// - TestTokenBlacklistExpiredCleanup +// - TestTokenBlacklistOldestEviction +// - TestTokenBlacklistMemoryUsage +// - TestConcurrentTokenBlacklistOperations func TestTokenCacheMemoryUsage(t *testing.T) { tc := NewTokenCache() diff --git a/jwt.go b/jwt.go index 7a03b8c..43e81c4 100644 --- a/jwt.go +++ b/jwt.go @@ -15,8 +15,10 @@ import ( "time" ) -var replayCacheMu sync.Mutex -var replayCache = make(map[string]time.Time) +var ( + replayCacheMu sync.Mutex + replayCache = make(map[string]time.Time) +) func cleanupReplayCache() { now := time.Now() @@ -164,6 +166,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } + func verifyAudience(tokenAudience interface{}, expectedAudience string) error { switch aud := tokenAudience.(type) { case string: diff --git a/main.go b/main.go index 526d182..8704ff0 100644 --- a/main.go +++ b/main.go @@ -9,11 +9,10 @@ import ( "net" "net/http" "net/url" + "runtime" "strings" "time" - "runtime" - "github.com/google/uuid" "golang.org/x/time/rate" ) @@ -52,7 +51,10 @@ func createDefaultHTTPClient() *http.Client { } } -const ConstSessionTimeout = 86400 // Session timeout in seconds +const ( + ConstSessionTimeout = 86400 // Session timeout in seconds + defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI +) // TokenVerifier interface for token verification type TokenVerifier interface { @@ -64,6 +66,13 @@ type JWTVerifier interface { VerifyJWTSignatureAndClaims(jwt *JWT, token string) error } +// TokenExchanger defines methods for OIDC token operations +type TokenExchanger interface { + ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) + GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) + RevokeTokenWithProvider(token, tokenType string) error +} + // TraefikOidc is the main struct for the OIDC middleware type TraefikOidc struct { next http.Handler @@ -74,7 +83,7 @@ type TraefikOidc struct { revocationURL string jwkCache JWKCacheInterface metadataCache *MetadataCache - tokenBlacklist *TokenBlacklist + tokenBlacklist *Cache // Replaced TokenBlacklist with generic Cache jwksURL string clientID string clientSecret string @@ -94,12 +103,13 @@ type TraefikOidc struct { allowedUserDomains map[string]struct{} allowedRolesAndGroups map[string]struct{} initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) - exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) - extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - initComplete chan struct{} - endSessionURL string - postLogoutRedirectURI string - sessionManager *SessionManager + // exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) // Replaced by interface + extractClaimsFunc func(tokenString string) (map[string]interface{}, error) + initComplete chan struct{} + endSessionURL string + postLogoutRedirectURI string + sessionManager *SessionManager + tokenExchanger TokenExchanger // Added field for mocking } // ProviderMetadata holds OIDC provider metadata @@ -155,6 +165,29 @@ func (t *TraefikOidc) VerifyToken(token string) error { // Cache the verified token t.cacheVerifiedToken(token, jwt.Claims) + // Add JTI to blacklist AFTER successful verification to prevent replay + if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" { + // Calculate expiry based on 'exp' claim if available, otherwise use default + expiry := time.Now().Add(defaultBlacklistDuration) + if expClaim, expOk := jwt.Claims["exp"].(float64); expOk { + expTime := time.Unix(int64(expClaim), 0) + tokenDuration := time.Until(expTime) + // Use token expiry if longer than default, capped at a reasonable max (e.g., 24h) + if tokenDuration > defaultBlacklistDuration && tokenDuration < (24*time.Hour) { + expiry = expTime + } else if tokenDuration <= 0 { + // If token already expired but somehow passed verification, use default + expiry = time.Now().Add(defaultBlacklistDuration) + } else { + // Use default if token expiry is shorter or excessively long + expiry = time.Now().Add(defaultBlacklistDuration) + } + } + // Use Set with a duration. Value 'true' is arbitrary, we only care about existence. + t.tokenBlacklist.Set(jti, true, time.Until(expiry)) + t.logger.Debugf("Added JTI %s to blacklist cache", jti) + } + return nil } @@ -165,11 +198,22 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error { return fmt.Errorf("rate limit exceeded") } - // Check if token is blacklisted - if t.tokenBlacklist.IsBlacklisted(token) { - return fmt.Errorf("token is blacklisted") + // Check if the raw token string itself is blacklisted (e.g., via explicit revocation) + if _, exists := t.tokenBlacklist.Get(token); exists { + return fmt.Errorf("token is blacklisted (raw string) in cache") } + // Also check if the JTI claim is blacklisted (replay detection) + claims, err := extractClaims(token) // Use existing helper + if err == nil { // Only check JTI if claims could be extracted + if jti, ok := claims["jti"].(string); ok && jti != "" { + if _, exists := t.tokenBlacklist.Get(jti); exists { + // Use a specific error message for replay + return fmt.Errorf("token replay detected (jti: %s) in cache", jti) + } + } + } // If claims extraction fails, proceed; full validation will catch token issues later. + return nil } @@ -296,7 +340,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h } return config.PostLogoutRedirectURI }(), - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist jwkCache: &JWKCache{}, metadataCache: NewMetadataCache(), clientID: config.ClientID, @@ -316,7 +360,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) t.extractClaimsFunc = extractClaims - t.exchangeCodeForTokenFunc = t.exchangeCodeForToken + // t.exchangeCodeForTokenFunc = t.exchangeCodeForToken // Removed, using interface now t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { t.defaultInitiateAuthentication(rw, req, session, redirectURL) } @@ -329,6 +373,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h t.tokenVerifier = t t.jwtVerifier = t t.startTokenCleanup() + t.tokenExchanger = t // Initialize the interface field to self go t.initializeMetadata(config.ProviderURL) return t, nil @@ -532,15 +577,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if !authenticated { + // Original logic: Always initiate authentication if not authenticated + t.logger.Debug("User not authenticated, initiating OIDC flow") t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return + return // Stop processing } if needsRefresh { refreshed := t.refreshToken(rw, req, session) if !refreshed { + // Original logic: Always handle failed refresh as an expired token + t.logger.Debug("Token refresh failed, handling as expired token") t.handleExpiredToken(rw, req, session, redirectURL) - return + return // Stop processing } } @@ -621,6 +670,151 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.next.ServeHTTP(rw, req) } +// handleExpiredToken manages token expiration by clearing the session +// and initiating a new authentication flow. +func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { + // Clear authentication data but preserve CSRF state + session.SetAuthenticated(false) + session.SetAccessToken("") + session.SetRefreshToken("") + session.SetEmail("") + + // Save the cleared session state + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save cleared session: %v", err) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + return + } + + t.defaultInitiateAuthentication(rw, req, session, redirectURL) +} + +// handleCallback processes the authentication callback from the OIDC provider. +// It validates the callback parameters, exchanges the authorization code for +// tokens, verifies the tokens, and establishes the user's session. +func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { + session, err := t.sessionManager.GetSession(req) + if err != nil { + t.logger.Errorf("Session error: %v", err) + http.Error(rw, "Session error", http.StatusInternalServerError) + return + } + + t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) + + // Check for errors in the callback + if req.URL.Query().Get("error") != "" { + errorDescription := req.URL.Query().Get("error_description") + t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) + http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) + return + } + + // Validate CSRF state + state := req.URL.Query().Get("state") + if state == "" { + t.logger.Error("No state in callback") + http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) + return + } + + csrfToken := session.GetCSRF() + if csrfToken == "" { + t.logger.Error("CSRF token missing in session") + http.Error(rw, "CSRF token missing", http.StatusBadRequest) + return + } + + if state != csrfToken { + t.logger.Error("State parameter does not match CSRF token in session") + http.Error(rw, "Invalid state parameter", http.StatusBadRequest) + return + } + + // Exchange code for tokens + code := req.URL.Query().Get("code") + if code == "" { + t.logger.Error("No code in callback") + http.Error(rw, "No code in callback", http.StatusBadRequest) + return + } + + // Get the code verifier from the session for PKCE flow + codeVerifier := session.GetCodeVerifier() + + tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier) + if err != nil { + t.logger.Errorf("Failed to exchange code for token: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Verify tokens and claims + // Use the exported VerifyToken method now that handleCallback is in main.go + if err := t.VerifyToken(tokenResponse.IDToken); err != nil { + t.logger.Errorf("Failed to verify id_token: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + claims, err := t.extractClaimsFunc(tokenResponse.IDToken) + if err != nil { + t.logger.Errorf("Failed to extract claims: %v", err) + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Verify nonce to prevent replay attacks + nonceClaim, ok := claims["nonce"].(string) + if !ok || nonceClaim == "" { + t.logger.Error("Nonce claim missing in id_token") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + sessionNonce := session.GetNonce() + if sessionNonce == "" { + t.logger.Error("Nonce not found in session") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + if nonceClaim != sessionNonce { + t.logger.Error("Nonce claim does not match session nonce") + http.Error(rw, "Authentication failed", http.StatusInternalServerError) + return + } + + // Validate user's email domain + // Use the unexported isAllowedDomain method now that handleCallback is in main.go + email, _ := claims["email"].(string) + if email == "" || !t.isAllowedDomain(email) { + t.logger.Errorf("Invalid or disallowed email: %s", email) + http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) + return + } + + // Update session with authentication data + session.SetAuthenticated(true) + session.SetEmail(email) + session.SetAccessToken(tokenResponse.IDToken) + session.SetRefreshToken(tokenResponse.RefreshToken) + + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session: %v", err) + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + + // Redirect to original path or root + redirectPath := "/" + if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { + redirectPath = incomingPath + } + + http.Redirect(rw, req, redirectPath, http.StatusFound) +} + // determineExcludedURL checks if the current request URL is in the excluded list func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { for excludedURL := range t.excludedURLs { @@ -675,17 +869,26 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return false, false, true // Session is invalid, consider it expired } - // Verify the token - if err := t.verifyToken(accessToken); err != nil { - t.logger.Errorf("Token verification failed: %v", err) - return false, false, true // Token is invalid, consider it expired + // Verify the token structure and signature first + jwt, err := parseJWT(accessToken) + if err != nil { + t.logger.Errorf("Failed to parse JWT during auth check: %v", err) + return false, false, true // Invalid format, treat as expired/invalid + } + if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil { + // Check if the error is specifically about expiration + if strings.Contains(err.Error(), "token has expired") { + t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh") + // Token is expired but otherwise valid, signal for refresh + return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible) + } + // Other verification error (signature, issuer, audience etc.) + t.logger.Errorf("Token verification failed (non-expiration): %v", err) + return false, false, true // Token is invalid for other reasons } - claims, err := extractClaims(accessToken) - if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) - return false, false, true - } + // Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early + claims := jwt.Claims expClaim, ok := claims["exp"].(float64) if !ok { @@ -696,17 +899,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo now := time.Now().Unix() expTime := int64(expClaim) - if now > expTime { - t.logger.Debug("Token has expired") - return false, false, true - } - - gracePeriod := time.Minute * 5 - if now+int64(gracePeriod.Seconds()) > expTime { - t.logger.Debug("Token will expire soon") - return true, true, false // Token will expire soon, needs refresh + // Expiration check is now handled within VerifyJWTSignatureAndClaims logic above + // We only get here if the token is valid and not expired + + // Check if token is nearing expiration (needs refresh proactively) + // Define a grace period, e.g., 5 minutes before actual expiry + refreshGracePeriod := int64(5 * 60) + if expTime-now < refreshGracePeriod { + t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod) + return true, true, false // Needs proactive refresh } + // Token is valid, not expired, and not nearing expiration return true, false, false } @@ -827,7 +1031,7 @@ func (t *TraefikOidc) startTokenCleanup() { for range ticker.C { t.logger.Debug("Starting token cleanup cycle") t.tokenCache.Cleanup() - t.tokenBlacklist.Cleanup() + // t.tokenBlacklist.Cleanup() // Removed: Generic Cache handles its own cleanup t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go // Removed runtime.GC() call } @@ -841,7 +1045,8 @@ func (t *TraefikOidc) RevokeToken(token string) { // Add to blacklist with default expiration expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration - t.tokenBlacklist.Add(token, expiry) + // Use Set with a duration. Value 'true' is arbitrary, we only care about existence. + t.tokenBlacklist.Set(token, true, time.Until(expiry)) } // RevokeTokenWithProvider revokes the token with the provider @@ -890,7 +1095,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return false } - newToken, err := t.getNewTokenWithRefreshToken(refreshToken) + newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) if err != nil { t.logger.Errorf("Failed to refresh token: %v", err) return false @@ -986,3 +1191,19 @@ func buildFullURL(scheme, host, path string) string { return fmt.Sprintf("%s://%s%s", scheme, host, path) } + +// --- TokenExchanger Interface Implementation --- + +// ExchangeCodeForToken implements the TokenExchanger interface. +// It calls the existing exchangeTokens helper function. +func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) { + // Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc + return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier) +} + +// GetNewTokenWithRefreshToken implements the TokenExchanger interface. +// It calls the existing getNewTokenWithRefreshToken helper function. +func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { + // Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc + return t.getNewTokenWithRefreshToken(refreshToken) +} diff --git a/main_test.go b/main_test.go index 4535798..fff7910 100644 --- a/main_test.go +++ b/main_test.go @@ -101,29 +101,47 @@ func (ts *TestSuite) Setup() { jwksURL: "https://test-jwks-url.com", revocationURL: "https://revocation-endpoint.com", limiter: rate.NewLimiter(rate.Every(time.Second), 10), - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), logger: logger, allowedUserDomains: map[string]struct{}{"example.com": {}}, excludedURLs: map[string]struct{}{"/favicon": {}}, httpClient: &http.Client{}, - extractClaimsFunc: extractClaims, - initComplete: make(chan struct{}), - sessionManager: ts.sessionManager, + // Explicitly set paths as New() is bypassed + redirURLPath: "/callback", // Assume default callback path for tests + logoutURLPath: "/callback/logout", // Assume default logout path for tests + tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests + extractClaimsFunc: extractClaims, + initComplete: make(chan struct{}), + sessionManager: ts.sessionManager, } close(ts.tOidc.initComplete) - ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc + // ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc // Removed ts.tOidc.tokenVerifier = ts.tOidc ts.tOidc.jwtVerifier = ts.tOidc + // Set default mock exchanger + ts.tOidc.tokenExchanger = &MockTokenExchanger{ + ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + // Default mock behavior for code exchange + return &TokenResponse{ + IDToken: ts.token, // Use the valid token from setup + AccessToken: ts.token, + RefreshToken: "default-refresh-token", + ExpiresIn: 3600, + }, nil + }, + RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) { + // Default mock behavior for refresh (can be overridden in tests) + return nil, fmt.Errorf("default mock: refresh not expected") + }, + RevokeTokenFunc: func(token, tokenType string) error { + // Default mock behavior for revoke + return nil + }, + } } -// Helper functions used by TraefikOidc -func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { - return &TokenResponse{ - IDToken: ts.token, - RefreshToken: "test-refresh-token", - }, nil -} +// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface. // MockJWKCache implements JWKCacheInterface type MockJWKCache struct { @@ -141,6 +159,34 @@ func (m *MockJWKCache) Cleanup() { m.Err = nil } +// MockTokenExchanger implements TokenExchanger for testing +type MockTokenExchanger struct { + ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) + RefreshTokenFunc func(refreshToken string) (*TokenResponse, error) + RevokeTokenFunc func(token, tokenType string) error +} + +func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + if m.ExchangeCodeFunc != nil { + return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier) + } + return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock") +} + +func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { + if m.RefreshTokenFunc != nil { + return m.RefreshTokenFunc(refreshToken) + } + return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock") +} + +func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error { + if m.RevokeTokenFunc != nil { + return m.RevokeTokenFunc(token, tokenType) + } + return fmt.Errorf("RevokeTokenFunc not implemented in mock") +} + // Helper function to create a JWT token func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) { header := map[string]interface{}{ @@ -228,13 +274,14 @@ func TestVerifyToken(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache for each test - ts.tOidc.tokenBlacklist = NewTokenBlacklist() + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { - ts.tOidc.tokenBlacklist.Add(tc.token, time.Now().Add(1*time.Hour)) + // Use Set with a duration. Value 'true' is arbitrary. + ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour) } if tc.rateLimit { @@ -282,13 +329,53 @@ func TestServeHTTP(t *testing.T) { ts.tOidc.next = nextHandler ts.tOidc.name = "test" + // Helper to create an expired token + createExpiredToken := func() string { + exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago + iat := time.Now().Add(-2 * time.Hour).Unix() + nbf := time.Now().Add(-2 * time.Hour).Unix() + expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce-expired", // Different nonce for clarity + "jti": generateRandomString(16), + }) + return expiredToken + } + + // Helper to create a new valid token (simulating refresh) + createNewValidToken := func() string { + exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour + iat := time.Now().Unix() + nbf := time.Now().Unix() + newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + // "nonce": "test-nonce-new", // Nonce is typically not included/validated in refreshed tokens + "jti": generateRandomString(16), + }) + return newToken + } + tests := []struct { - name string - requestPath string - sessionValues map[interface{}]interface{} - expectedStatus int - expectedBody string - setupSession func(*SessionData) + name string + requestPath string + sessionValues map[interface{}]interface{} + expectedStatus int + expectedBody string + setupSession func(*SessionData) + mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) + assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks }{ { name: "Excluded URL", @@ -299,28 +386,77 @@ func TestServeHTTP(t *testing.T) { { name: "Unauthenticated request to protected URL", requestPath: "/protected", - expectedStatus: http.StatusFound, + expectedStatus: http.StatusFound, // Expect redirect to OIDC }, { - name: "Authenticated request to protected URL", + name: "Authenticated request to protected URL (Valid Token)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") - session.SetAccessToken(ts.token) + session.SetAccessToken(ts.token) // Use the valid token generated in Setup + session.SetRefreshToken("valid-refresh-token") }, expectedStatus: http.StatusOK, expectedBody: "OK", }, + { + name: "Authenticated request with expired token and successful refresh", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(true) // Still marked authenticated initially + session.SetEmail("user@example.com") + session.SetAccessToken(createExpiredToken()) // Set expired token + session.SetRefreshToken("valid-refresh-token") // Set valid refresh token + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + if refreshToken != "valid-refresh-token" { + return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken) + } + // Simulate successful refresh + newToken := createNewValidToken() + return &TokenResponse{ + IDToken: newToken, // Return new valid token + AccessToken: newToken, // Often the same as ID token in tests + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + }, nil + } + }, + expectedStatus: http.StatusOK, // Expect success after refresh + expectedBody: "OK", + assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) { + // Create a new request to read the cookies set by the response recorder + reqForCookieRead := httptest.NewRequest("GET", "/protected", nil) + for _, cookie := range rr.Result().Cookies() { + reqForCookieRead.AddCookie(cookie) + } + // Get session based on response cookies + session, err := sessionManager.GetSession(reqForCookieRead) + if err != nil { + t.Fatalf("Failed to get session after request: %v", err) + } + // Assert new tokens are in the session + // Direct comparison with createNewValidToken() is flawed as it generates a new token each time. + // Instead, check if the token was updated (not empty) and verify the refresh token. + if session.GetAccessToken() == "" { + t.Errorf("Expected access token to be updated in session, but it was empty") + } + if session.GetRefreshToken() != "new-refresh-token" { + t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken()) + } + }, + }, { name: "Logout URL", - requestPath: "/logout", + requestPath: "/logout", // Assuming logout path is configured or defaulted correctly setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(ts.token) }, - expectedStatus: http.StatusOK, + expectedStatus: http.StatusFound, // Expect redirect after logout expectedBody: "", }, } @@ -328,40 +464,79 @@ func TestServeHTTP(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest("GET", tc.requestPath, nil) - req.Header.Set("X-Forwarded-Proto", "http") - req.Header.Set("X-Forwarded-Host", "localhost") + // Set common headers needed by the logic (determineScheme, determineHost) + req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that + req.Header.Set("X-Forwarded-Host", "testhost.com") + req.Host = "testhost.com" // Also set Host header + rr := httptest.NewRecorder() // Setup session if needed session, err := ts.tOidc.sessionManager.GetSession(req) if err != nil { - t.Fatalf("Failed to get session: %v", err) + t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err) } if tc.setupSession != nil { tc.setupSession(session) - if err := session.Save(req, rr); err != nil { - t.Fatalf("Failed to save session: %v", err) + // Save session to recorder to get cookies + saveRecorder := httptest.NewRecorder() + if err := session.Save(req, saveRecorder); err != nil { + t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err) } - - // Copy cookies to the new request - for _, cookie := range rr.Result().Cookies() { + // Copy cookies from save recorder to the actual request + for _, cookie := range saveRecorder.Result().Cookies() { req.AddCookie(cookie) } - rr = httptest.NewRecorder() + } + + // Mocking setup for TokenExchanger + originalExchanger := ts.tOidc.tokenExchanger // Store original + mockExchanger, isMock := originalExchanger.(*MockTokenExchanger) + if !isMock { + // This case should ideally not happen if Setup correctly assigns the mock, + // but handle it defensively. + t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.") + mockExchanger = &MockTokenExchanger{ + ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken, + RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken, + RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider, + } + ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock + } + + // Override specific mock methods if needed for the test case + originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func + if tc.mockRefreshTokenFunc != nil { + // Assign the test case specific mock function + mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken) } // Call ServeHTTP ts.tOidc.ServeHTTP(rr, req) - // Check response - if rr.Code != tc.expectedStatus { - t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) + // Restore original exchanger and mock function state + ts.tOidc.tokenExchanger = originalExchanger + if tc.mockRefreshTokenFunc != nil && mockExchanger != nil { + // Restore the previous mock function if we overrode it + mockExchanger.RefreshTokenFunc = originalMockRefreshFunc } + + // Check response status + if rr.Code != tc.expectedStatus { + t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String()) + } + + // Check response body if expected if tc.expectedBody != "" { if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { - t.Errorf("Expected body %q, got %q", tc.expectedBody, body) + t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body) } } + + // Perform post-request session assertions if defined + if tc.assertSessionAfterRequest != nil { + tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager) + } }) } } @@ -552,17 +727,39 @@ func TestHandleCallback(t *testing.T) { name: "Disallowed Email", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { + // Generate a unique token for this test case to avoid replay issues + // Use claims relevant to this test (disallowed email) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject-disallowed", + "email": "user@disallowed.com", // The disallowed email for this test + "nonce": "test-nonce", // Match the nonce set in sessionSetupFunc + "jti": generateRandomString(16), // Unique JTI + }) + if err != nil { + return nil, fmt.Errorf("failed to create disallowed token for test: %w", err) + } return &TokenResponse{ - IDToken: ts.token, - RefreshToken: "test-refresh-token", - }, nil - }, - extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { - return map[string]interface{}{ - "email": "user@disallowed.com", - "nonce": "test-nonce", + IDToken: disallowedToken, + RefreshToken: "test-refresh-token-disallowed", }, nil }, + // Remove mock extractClaimsFunc - let the real one parse the disallowedToken + // The test should still fail correctly on the email check later. + // extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { + // return map[string]interface{}{ + // "email": "user@disallowed.com", + // "nonce": "test-nonce", + // }, nil + // }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") @@ -635,20 +832,61 @@ func TestHandleCallback(t *testing.T) { } for _, tc := range tests { + tc := tc // Capture range variable t.Run(tc.name, func(t *testing.T) { + // Clear the global replay cache before each test run + replayCacheMu.Lock() + replayCache = make(map[string]time.Time) // Reset the global cache + replayCacheMu.Unlock() + + // Explicitly clear the shared blacklist at the start of each sub-test + // to ensure no state leaks, even though we expect the local one to be used. + // Note: This line might be redundant now that the verifier is local, but keep for safety. + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist + logger := NewLogger("info") sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger) // Create a new instance for each test to avoid state carryover - tOidc := &TraefikOidc{ - allowedUserDomains: map[string]struct{}{"example.com": {}}, - logger: logger, - exchangeCodeForTokenFunc: tc.exchangeCodeForToken, - extractClaimsFunc: tc.extractClaimsFunc, - tokenVerifier: ts.tOidc.tokenVerifier, - jwtVerifier: ts.tOidc.jwtVerifier, - sessionManager: sessionManager, + instanceExtractClaimsFunc := tc.extractClaimsFunc + if instanceExtractClaimsFunc == nil { + instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case } + tOidc := &TraefikOidc{ + allowedUserDomains: map[string]struct{}{"example.com": {}}, + logger: logger, + // exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field + extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function + tokenVerifier: nil, // Will be set to self below + jwtVerifier: nil, // Temporarily nil, will be set below + sessionManager: sessionManager, + tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run + ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { + // Wrap the test case function to match the required signature + if tc.exchangeCodeForToken != nil { + // Only call if the test case provided a function + return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier) + } + // Provide a default behavior or error if no mock was provided for this test case + return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case") + }, + // Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback + }, + tokenCache: NewTokenCache(), // Initialize token cache + limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter + tokenBlacklist: NewCache(), // Initialize token blacklist cache + + // Add potentially missing fields based on New() comparison + clientID: ts.tOidc.clientID, + issuerURL: ts.tOidc.issuerURL, + jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite + httpClient: ts.tOidc.httpClient, + initComplete: make(chan struct{}), // Initialize the channel + // Setting other fields like paths, enablePKCE etc. if needed + } + tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW + tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW + close(tOidc.initComplete) // Mark this test instance as initialized // Create request and response recorder req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil) @@ -839,13 +1077,14 @@ func TestOIDCHandler(t *testing.T) { tc := tc // Capture range variable t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache - ts.tOidc.tokenBlacklist = NewTokenBlacklist() + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { - ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour)) + // Use Set with a duration. Value 'true' is arbitrary. + ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour) } if tc.rateLimit { @@ -948,7 +1187,7 @@ func TestHandleLogout(t *testing.T) { endSessionURL: tc.endSessionURL, scheme: "http", logger: logger, - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist httpClient: &http.Client{}, clientID: "test-client-id", clientSecret: "test-client-secret", @@ -1018,13 +1257,13 @@ func TestHandleLogout(t *testing.T) { // Check token blacklist if token := session.GetAccessToken(); token != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(token) { - t.Error("Access token was not blacklisted") + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { + t.Error("Access token was not blacklisted in cache") } } if token := session.GetRefreshToken(); token != "" { - if !tOidc.tokenBlacklist.IsBlacklisted(token) { - t.Error("Refresh token was not blacklisted") + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { + t.Error("Refresh token was not blacklisted in cache") } } }) @@ -1121,7 +1360,7 @@ func TestRevokeToken(t *testing.T) { t.Run("Token revocation", func(t *testing.T) { // Create a new instance for this specific test tOidc := &TraefikOidc{ - tokenBlacklist: NewTokenBlacklist(), + tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), } @@ -1136,8 +1375,8 @@ func TestRevokeToken(t *testing.T) { t.Error("Token was not removed from cache") } - // Verify token was added to blacklist - if !tOidc.tokenBlacklist.IsBlacklisted(token) { + // Verify token was added to blacklist cache + if _, exists := tOidc.tokenBlacklist.Get(token); !exists { t.Error("Token was not added to blacklist") } }) @@ -1404,7 +1643,6 @@ func TestMultipleMiddlewareInstances(t *testing.T) { middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test") - if err != nil { t.Fatalf("Failed to create middleware for route %s: %v", route, err) } @@ -2043,7 +2281,6 @@ func TestExchangeCodeForToken(t *testing.T) { // Test exchangeCodeForToken response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier) - if err != nil { t.Errorf("Unexpected error: %v", err) } diff --git a/metadata_cache.go b/metadata_cache.go index 64f4b6c..0e96d59 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -34,6 +34,7 @@ func (c *MetadataCache) Cleanup() { c.metadata = nil } } + func (c *MetadataCache) isCacheValid() bool { return c.metadata != nil && time.Now().Before(c.expiresAt) } @@ -67,15 +68,9 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, } c.metadata = metadata - // Calculate expiration time based on usage patterns - usageCount := 0 // This should be replaced with actual usage tracking logic - if usageCount < 10 { - c.expiresAt = time.Now().Add(30 * time.Minute) - } else if usageCount < 50 { - c.expiresAt = time.Now().Add(1 * time.Hour) - } else { - c.expiresAt = time.Now().Add(2 * time.Hour) - } + // Set a fixed cache lifetime (e.g., 1 hour) + // TODO: Consider making this configurable or respecting HTTP cache headers + c.expiresAt = time.Now().Add(1 * time.Hour) // End of GetMetadata return metadata, nil diff --git a/session_test.go b/session_test.go index 6ea172a..c145ddc 100644 --- a/session_test.go +++ b/session_test.go @@ -1,7 +1,9 @@ package traefikoidc import ( - "math/rand" + "crypto/rand" + "fmt" + "math/big" "net/http/httptest" "strings" "testing" @@ -12,7 +14,12 @@ func generateRandomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, length) for i := range b { - b[i] = charset[rand.Intn(len(charset))] + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) + if err != nil { + // Handle error appropriately in a real application, maybe panic in test helper + panic(fmt.Sprintf("crypto/rand failed: %v", err)) + } + b[i] = charset[num.Int64()] } return string(b) }