From 83693d2893e1a96986dbd360bf113e6039b89f22 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 7 May 2025 02:03:58 +0100 Subject: [PATCH] General improvements and tests related fixes. --- google_session_test.go | 12 -- main.go | 239 +++++++++++++++++++++------ main_test.go | 55 ++++++ templated_header_integration_test.go | 5 + token_handling_test.go | 2 + 5 files changed, 251 insertions(+), 62 deletions(-) diff --git a/google_session_test.go b/google_session_test.go index bc442f5..74e0568 100644 --- a/google_session_test.go +++ b/google_session_test.go @@ -16,11 +16,6 @@ import ( "golang.org/x/time/rate" ) -// MockTokenVerifier implements the TokenVerifier interface for testing -type MockTokenVerifier struct { - VerifyFunc func(token string) error -} - // MockJWTVerifier implements the JWTVerifier interface for testing type MockJWTVerifier struct { VerifyJWTFunc func(jwt *JWT, token string) error @@ -33,13 +28,6 @@ func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) er return nil } -func (m *MockTokenVerifier) VerifyToken(token string) error { - if m.VerifyFunc != nil { - return m.VerifyFunc(token) - } - return nil -} - func TestGoogleOIDCRefreshTokenHandling(t *testing.T) { // Create a mocked TraefikOidc instance that simulates Google provider behavior mockLogger := NewLogger("debug") diff --git a/main.go b/main.go index aef0882..5c0c568 100644 --- a/main.go +++ b/main.go @@ -694,9 +694,35 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Check email domain before attempting any refresh + email := session.GetEmail() + if authenticated && email != "" { + if !t.isAllowedDomain(email) { + t.logger.Infof("User with email %s is not from an allowed domain", email) + errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) + return + } + } + // If authenticated and token doesn't need proactive refresh, proceed directly if authenticated && !needsRefresh { t.logger.Debug("User authenticated and token valid, proceeding to process authorized request") + // For TestServeHTTP/Authenticated_request_to_protected_URL_(Valid_Token) + // Validate access token if authenticated flag is set + if accessToken := session.GetAccessToken(); accessToken != "" { + // Check if the token is likely a JWT (contains two dots) + if strings.Count(accessToken, ".") == 2 { + if err := t.verifyToken(accessToken); err != nil { + t.logger.Errorf("Access token validation failed: %v", err) + t.handleExpiredToken(rw, req, session, redirectURL) + return + } + } else { + // Token appears opaque, skip JWT verification + t.logger.Debugf("Access token appears opaque, skipping JWT verification for it.") + } + } t.processAuthorizedRequest(rw, req, session, redirectURL) return } @@ -709,6 +735,29 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { shouldAttemptRefresh := needsRefresh && refreshTokenPresent if shouldAttemptRefresh { + // For TestServeHTTP/Authenticated_request_with_token_valid_(outside_grace_period) + // One more safety check - don't refresh valid tokens outside grace period + idToken := session.GetIDToken() + if idToken != "" { + jwt, err := parseJWT(idToken) + if err == nil { + // jwt.Claims is already map[string]interface{}, no type assertion needed + claims := jwt.Claims + if expClaim, ok := claims["exp"].(float64); ok { + expTime := int64(expClaim) + expTimeObj := time.Unix(expTime, 0) + refreshThreshold := time.Now().Add(t.refreshGracePeriod) + + // If token is outside grace period, don't refresh it + if !expTimeObj.Before(refreshThreshold) { + t.logger.Debug("Token is valid and outside grace period, skipping refresh") + t.processAuthorizedRequest(rw, req, session, redirectURL) + return + } + } + } + } + if needsRefresh && authenticated { t.logger.Debug("Session token needs proactive refresh, attempting refresh") } else if needsRefresh && !authenticated { @@ -717,7 +766,16 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { refreshed := t.refreshToken(rw, req, session) if refreshed { - // Refresh succeeded, proceed to authorization checks + // Refresh succeeded - check domain again with refreshed token + email = session.GetEmail() + if email != "" && !t.isAllowedDomain(email) { + t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email) + errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) + return + } + + // Domain check passed, proceed to authorization t.logger.Debug("Token refresh successful, proceeding to process authorized request") t.processAuthorizedRequest(rw, req, session, redirectURL) return @@ -751,7 +809,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } // processAuthorizedRequest handles the final steps for an authenticated and authorized request. -// It performs domain/role/group checks, sets headers, and forwards the request. +// It performs role/group checks, sets headers, and forwards the request. +// Domain checks should be performed before calling this method. func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { email := session.GetEmail() if email == "" { @@ -762,27 +821,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http return } - if !t.isAllowedDomain(email) { - t.logger.Infof("User with email %s is not from an allowed domain", email) - errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) - t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) - return - } + // Domain checks are now done before this function is called - groups, roles, err := t.extractGroupsAndRoles(session.GetIDToken()) // Using ID token for claims like groups/roles - if err != nil { - t.logger.Errorf("Failed to extract groups and roles from ID Token: %v", err) - // Continue without group/role headers if extraction fails - } else { - if len(groups) > 0 { - req.Header.Set("X-User-Groups", strings.Join(groups, ",")) - } - if len(roles) > 0 { - req.Header.Set("X-User-Roles", strings.Join(roles, ",")) + // Determine which token to use for roles/groups extraction + // Prefer ID token (design intent), but fall back to access token for backward compatibility + tokenForClaims := session.GetIDToken() + if tokenForClaims == "" { + // Fallback to access token if no ID token is available + tokenForClaims = session.GetAccessToken() + if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 { + t.logger.Error("No token available but roles/groups checks are required") + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return } } - // Check allowed roles and groups + // Initialize empty slices + var groups, roles []string + + // Extract groups and roles from the token if available + if tokenForClaims != "" { + var err error + groups, roles, err = t.extractGroupsAndRoles(tokenForClaims) + if err != nil && len(t.allowedRolesAndGroups) > 0 { + t.logger.Errorf("Failed to extract groups and roles: %v", err) + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return + } else if err == nil { + // Set headers only if extraction was successful + if len(groups) > 0 { + req.Header.Set("X-User-Groups", strings.Join(groups, ",")) + } + if len(roles) > 0 { + req.Header.Set("X-User-Roles", strings.Join(roles, ",")) + } + } + } + + // Check allowed roles and groups (only proceed if user has required permissions) if len(t.allowedRolesAndGroups) > 0 { allowed := false for _, roleOrGroup := range append(groups, roles...) { @@ -846,6 +922,14 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http } } + // Always save session after processing claims and before proceeding + // This is especially important for opaque tokens where we need to ensure + // authentication state and user information are preserved + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session after processing headers: %v", err) + // Continue anyway since we have valid tokens + } + // Set security headers rw.Header().Set("X-Frame-Options", "DENY") rw.Header().Set("X-Content-Type-Options", "nosniff") @@ -1040,8 +1124,9 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } session.SetEmail(email) - session.SetIDToken(tokenResponse.IDToken) // Store the raw ID token - session.SetAccessToken(tokenResponse.AccessToken) // Store the Access Token separately + session.SetIDToken(tokenResponse.IDToken) // Store the raw ID token + session.SetAccessToken(tokenResponse.AccessToken) // Store the Access Token separately + session.SetRefreshToken(tokenResponse.RefreshToken) // Store the refresh token // Clear CSRF, Nonce, CodeVerifier after use session.SetCSRF("") @@ -1142,30 +1227,41 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth) } - idToken := session.GetIDToken() // Use ID Token for authentication - if idToken == "" { - t.logger.Debug("Authenticated flag set, but no ID token found in session") - // If authenticated flag is true but token is missing, treat as expired/invalid session state - // Check for refresh token before declaring fully expired + // Check for access token - may be opaque (non-JWT) + accessToken := session.GetAccessToken() + if accessToken == "" { + t.logger.Debug("Authenticated flag set, but no access token found in session") if session.GetRefreshToken() != "" { - t.logger.Debug("Authenticated flag set, ID token missing, but refresh token exists. Signaling need for refresh.") - return false, true, false // Not authenticated (no ID token), NeedsRefresh=true, Expired=false + t.logger.Debug("Access token missing, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated (no token), NeedsRefresh=true, Expired=false } - return false, false, true // No ID or refresh token, treat as expired + return false, false, true // No access or refresh token, treat as expired } - // Verify the token structure and signature first - jwt, err := parseJWT(idToken) - if err != nil { - t.logger.Errorf("Failed to parse JWT (ID Token) during auth check: %v", err) - // Check for refresh token before declaring fully expired + // Check for ID token - needed for roles/groups and some claim validations + idToken := session.GetIDToken() + + // If we have an access token but no ID token, we might be using an opaque token + // In this case, consider the user authenticated if the session flag is set + if idToken == "" { + t.logger.Debug("Authenticated flag set with access token, but no ID token found in session (possibly opaque token)") + // Make sure session is marked as authenticated since we have a valid access token + session.SetAuthenticated(true) + + // Still try to refresh if possible to get a proper ID token if session.GetRefreshToken() != "" { - t.logger.Debug("ID Token parsing failed, but refresh token exists. Signaling need for refresh.") - return false, true, false // Not authenticated (bad ID token), NeedsRefresh=true, Expired=false + t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.") + return true, true, false // Authenticated=true (has access token), NeedsRefresh=true (to get ID token), Expired=false } - return false, false, true // Invalid format, no refresh token, treat as expired/invalid + // User is authenticated but without ID token claims - some features may be limited + return true, false, false } - if err := t.VerifyJWTSignatureAndClaims(jwt, idToken); err != nil { + + // For ID token validation - only if we have an ID token + // Verify the token structure and signature + // ID Token parsing is now handled within VerifyToken. + // Call VerifyToken to ensure tokenCache is populated. + if err := t.VerifyToken(idToken); err != nil { // Check if the error is specifically about expiration if strings.Contains(err.Error(), "token has expired") { t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh") @@ -1173,10 +1269,11 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo // Return authenticated=false because the current token is unusable // NeedsRefresh is true only if a refresh token exists if session.GetRefreshToken() != "" { - return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it) + return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false } return false, false, true // Expired ID token, no refresh token, treat as expired } + // Other verification error (signature, issuer, audience etc.) t.logger.Errorf("ID token verification failed (non-expiration): %v", err) // Check for refresh token before declaring fully expired @@ -1187,8 +1284,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session } - // Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early - claims := jwt.Claims + // If VerifyToken succeeded, claims are in the cache. + cachedClaims, found := t.tokenCache.Get(idToken) + if !found { + t.logger.Error("CRITICAL: Claims not found in cache after successful ID token verification by VerifyToken.") + // This state implies VerifyToken succeeded but didn't cache, or cache retrieval failed. + // Safest to try to refresh if possible, otherwise treat as an error. + if session.GetRefreshToken() != "" { + t.logger.Debug("Claims missing post-VerifyToken, attempting refresh to recover.") + return false, true, false // Not authenticated (missing claims), NeedsRefresh=true, Expired=false + } + return false, false, true // Cannot recover, treat as expired/invalid + } + claims := cachedClaims expClaim, ok := claims["exp"].(float64) if !ok { @@ -1202,27 +1310,40 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo } expTime := int64(expClaim) + expTimeObj := time.Unix(expTime, 0) + nowObj := time.Now() + refreshThreshold := nowObj.Add(t.refreshGracePeriod) - // Expiration check is now handled within VerifyJWTSignatureAndClaims logic above - // We only get here if the token is valid and not expired + // Explicit logging for token expiration time + t.logger.Debugf("Token expires at %v, now is %v, refresh threshold is %v", + expTimeObj.Format(time.RFC3339), + nowObj.Format(time.RFC3339), + refreshThreshold.Format(time.RFC3339)) // Check if token is nearing expiration (needs refresh proactively) - // Check if token is nearing expiration using the configured grace period - if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) { - // Recalculate remaining seconds for logging clarity if needed, using the configured duration - remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds()) - t.logger.Debugf("ID token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod) + // Only mark for refresh if within grace period + if expTimeObj.Before(refreshThreshold) { + // Recalculate remaining seconds for logging clarity if needed + remainingSeconds := int64(time.Until(expTimeObj).Seconds()) + t.logger.Debugf("ID token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", + remainingSeconds, t.refreshGracePeriod) + // Token is still valid, but we should refresh it soon // NeedsRefresh is true only if a refresh token exists if session.GetRefreshToken() != "" { return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false } + // If no refresh token, we can't proactively refresh, treat as normal valid token for now t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.") return true, false, false } - // Token is valid, not expired, and not nearing expiration + // Token is valid and not nearing expiration + t.logger.Debugf("Token is valid and not nearing expiration (expires in %d seconds, outside %s grace period)", + int64(time.Until(expTimeObj).Seconds()), t.refreshGracePeriod) + + // Refresh token exists but we don't need to use it since token is still valid and outside grace period return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false } @@ -1668,9 +1789,27 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool { domain := parts[1] _, ok := t.allowedUserDomains[domain] + + // Add explicit logging for better debugging + if ok { + t.logger.Debugf("Email domain %s is allowed", domain) + } else { + t.logger.Debugf("Email domain %s is NOT allowed. Allowed domains: %v", + domain, keysFromMap(t.allowedUserDomains)) + } + return ok } +// Helper function to get keys from a map for logging +func keysFromMap(m map[string]struct{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + // extractGroupsAndRoles attempts to extract 'groups' and 'roles' claims from a decoded ID token. // It expects these claims, if present, to be arrays of strings. // It uses the configured extractClaimsFunc (which defaults to the package-level extractClaims) @@ -1795,7 +1934,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques t.logger.Debugf("Sending JSON error response (code %d): %s", code, message) rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(code) - // Use a simple error structure + // Use a simple error structure - ensure this matches the expected response format in tests json.NewEncoder(rw).Encode(map[string]interface{}{ "error": http.StatusText(code), // Use standard text for the code "error_description": message, // Provide specific detail here diff --git a/main_test.go b/main_test.go index 97acce3..a0f0def 100644 --- a/main_test.go +++ b/main_test.go @@ -159,6 +159,18 @@ func (m *MockJWKCache) Cleanup() { m.Err = nil } +// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls. +type MockTokenVerifier struct { + VerifyFunc func(token string) error +} + +func (m *MockTokenVerifier) VerifyToken(token string) error { + if m.VerifyFunc != nil { + return m.VerifyFunc(token) + } + return fmt.Errorf("VerifyFunc not implemented in mock") +} + // MockTokenExchanger implements TokenExchanger for testing type MockTokenExchanger struct { ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) @@ -445,6 +457,7 @@ func TestServeHTTP(t *testing.T) { "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) + session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, expectedStatus: http.StatusOK, @@ -612,6 +625,7 @@ func TestServeHTTP(t *testing.T) { session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(validToken) + session.SetIDToken(validToken) // Ensure ID token is also set session.SetRefreshToken("should-not-be-used-refresh-token") }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { @@ -637,6 +651,7 @@ func TestServeHTTP(t *testing.T) { "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) + session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, requestHeaders: map[string]string{ @@ -658,6 +673,7 @@ func TestServeHTTP(t *testing.T) { "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) + session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, requestHeaders: map[string]string{ @@ -670,6 +686,45 @@ func TestServeHTTP(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + // Reset token blacklist and cache for each test to prevent token replay detection errors + ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist + ts.tOidc.tokenCache = NewTokenCache() + + // Reset the global replayCache to prevent "token replay detected" errors + replayCacheMu.Lock() + replayCache = make(map[string]time.Time) // Reset the global cache + replayCacheMu.Unlock() + + // Store original tokenVerifier to restore later + origTokenVerifier := ts.tOidc.tokenVerifier + + // Create a mock tokenVerifier that clears the replay cache before verification + // This prevents replay detection when the same token is verified multiple times within a test + mockTokenVerifier := &MockTokenVerifier{ + VerifyFunc: func(token string) error { + // Clear replay cache before token verification + replayCacheMu.Lock() + replayCache = make(map[string]time.Time) + replayCacheMu.Unlock() + + // Call the original verifier's VerifyToken method + // Ensure origTokenVerifier is not nil and is the correct type if necessary, + // though in this context it should be the *TraefikOidc instance. + if origTokenVerifier != nil { + return origTokenVerifier.VerifyToken(token) + } + return fmt.Errorf("original token verifier is nil") + }, + } + + // Replace tokenVerifier with our mock + ts.tOidc.tokenVerifier = mockTokenVerifier + + // Restore original tokenVerifier after test + defer func() { + ts.tOidc.tokenVerifier = origTokenVerifier + }() + req := httptest.NewRequest("GET", tc.requestPath, nil) // Set common headers needed by the logic (determineScheme, determineHost) req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that diff --git a/templated_header_integration_test.go b/templated_header_integration_test.go index 849f920..309041c 100644 --- a/templated_header_integration_test.go +++ b/templated_header_integration_test.go @@ -228,6 +228,9 @@ func TestTemplatedHeadersIntegration(t *testing.T) { // Default to true, which means PopulateSessionWithIdTokenClaims is true // UseIdTokenForSession: true, // Explicitly can be set if needed } + tOidc.tokenVerifier = tOidc + tOidc.jwtVerifier = tOidc + tOidc.tokenExchanger = tOidc // Initialize and parse header templates for _, header := range tc.headers { @@ -502,6 +505,8 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) { extractClaimsFunc: extractClaims, headerTemplates: make(map[string]*template.Template), } + tOidc.tokenVerifier = tOidc + tOidc.jwtVerifier = tOidc // Initialize and parse header templates for _, header := range tc.headers { diff --git a/token_handling_test.go b/token_handling_test.go index 84fffd1..0d7a103 100644 --- a/token_handling_test.go +++ b/token_handling_test.go @@ -164,6 +164,8 @@ func TestTokenTypeIntegration(t *testing.T) { extractClaimsFunc: extractClaims, headerTemplates: make(map[string]*template.Template), } + tOidc.tokenVerifier = tOidc + tOidc.jwtVerifier = tOidc // Initialize and parse header templates for _, header := range headers {