diff --git a/main.go b/main.go index 9e77b12..1821bfd 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,14 @@ import ( "golang.org/x/time/rate" ) +// min returns the smaller of x or y. +func min(x, y int) int { + if x > y { + return y + } + return x +} + // createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication. // It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection), // connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout. @@ -442,6 +450,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) if err != nil { t.logger.Errorf("Failed to get provider metadata: %v", err) + // Consider retrying or handling this more gracefully return } @@ -457,7 +466,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) { return } - t.logger.Error("Received nil metadata") + t.logger.Error("Received nil metadata during initialization") + // Consider what should happen if metadata is nil after GetMetadata returns no error } // updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.) @@ -497,6 +507,8 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { if metadata != nil { t.updateMetadataEndpoints(metadata) t.logger.Debug("Successfully refreshed metadata") + } else { + t.logger.Error("Received nil metadata during refresh") } } } @@ -544,7 +556,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo if delay > maxDelay { delay = maxDelay } - l.Debugf("Failed to fetch provider metadata, retrying in %s", delay) + l.Debugf("Failed to fetch provider metadata (attempt %d/%d), retrying in %s. Error: %v", attempt+1, maxRetries, delay, err) time.Sleep(delay) } @@ -568,64 +580,56 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad return nil, fmt.Errorf("failed to fetch provider metadata: %w", err) } if resp == nil { - return nil, fmt.Errorf("received nil response from provider") + return nil, fmt.Errorf("received nil response from provider at %s", wellKnownURL) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to fetch provider metadata: status code %d", resp.StatusCode) + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to fetch provider metadata from %s: status code %d, body: %s", wellKnownURL, resp.StatusCode, string(bodyBytes)) } var metadata ProviderMetadata if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { - return nil, fmt.Errorf("failed to decode provider metadata: %w", err) + // Attempt to read body for better error context if decoding fails + // Note: resp.Body might be partially read by Decode, so read remaining + bodyBytes, readErr := io.ReadAll(io.MultiReader(json.NewDecoder(resp.Body).Buffered(), resp.Body)) + if readErr != nil { + bodyBytes = []byte(fmt.Sprintf("(failed to read response body: %v)", readErr)) + } + return nil, fmt.Errorf("failed to decode provider metadata from %s: %w. Response body: %s", wellKnownURL, err, string(bodyBytes)) } return &metadata, nil } // ServeHTTP is the main entry point for incoming requests to the middleware. -// It orchestrates the OIDC authentication flow: -// 1. Waits for initial OIDC metadata discovery to complete (with timeout). -// 2. Checks if the request path is excluded from authentication. -// 3. Checks if the request is for Server-Sent Events and bypasses if so. -// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing. -// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout). -// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry). -// 7. If the token is expired, handles it (initiates re-auth). -// 8. If the user is not authenticated, initiates authentication. -// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure -// by returning 401 for API clients or initiating re-auth for browsers. -// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups). -// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*) -// and adds security headers (X-Frame-Options, etc.) to the response. -// 12. Forwards the request to the next handler in the chain. +// It orchestrates the OIDC authentication flow. func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // --- Initialization Check --- select { case <-t.initComplete: - if t.issuerURL == "" { - t.logger.Error("OIDC provider metadata initialization failed") - http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable) + if t.issuerURL == "" { // Check if initialization actually succeeded + t.logger.Error("OIDC provider metadata initialization failed or incomplete") + http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable) return } case <-req.Context().Done(): - t.logger.Debug("Request cancelled") - http.Error(rw, "Request cancelled", http.StatusServiceUnavailable) + t.logger.Debug("Request cancelled while waiting for OIDC initialization") + http.Error(rw, "Request cancelled", http.StatusRequestTimeout) // 408 might be more appropriate return - case <-time.After(30 * time.Second): + case <-time.After(30 * time.Second): // Timeout for initialization t.logger.Error("Timeout waiting for OIDC initialization") - http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable) + http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable) return } - // Check if URL is excluded + // --- Excluded Paths & SSE Check --- if t.determineExcludedURL(req.URL.Path) { t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path) t.next.ServeHTTP(rw, req) return } - - // Check if the request expects Server-Sent Events acceptHeader := req.Header.Get("Accept") if strings.Contains(acceptHeader, "text/event-stream") { t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader) @@ -633,80 +637,119 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - // Get session + // --- Session Retrieval --- session, err := t.sessionManager.GetSession(req) if err != nil { - t.logger.Errorf("Error getting session: %v", err) - - // Obtain a new session and clear any residual session cookies - session, _ = t.sessionManager.GetSession(req) - session.Clear(req, rw) - - // Build redirect URL + // Log the specific session error + t.logger.Errorf("Error getting session: %v. Initiating authentication.", err) + // Attempt to get a new session to store CSRF etc. + session, _ = t.sessionManager.GetSession(req) // Ignore error here, proceed with new session + if session != nil { + // Pass rw to ensure expiring cookies are sent if possible + if clearErr := session.Clear(req, rw); clearErr != nil { + t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr) + } + } else { + // If even getting a new session fails, something is very wrong + t.logger.Error("Critical session error: Failed to get even a new session.") + http.Error(rw, "Critical session error", http.StatusInternalServerError) + return + } scheme := t.determineScheme(req) host := t.determineHost(req) redirectURL := buildFullURL(scheme, host, t.redirURLPath) - - // Initiate authentication t.defaultInitiateAuthentication(rw, req, session, redirectURL) return } - // Build redirect URL + // --- URL Handling (Callback, Logout) --- scheme := t.determineScheme(req) host := t.determineHost(req) - redirectURL := buildFullURL(scheme, host, t.redirURLPath) + redirectURL := buildFullURL(scheme, host, t.redirURLPath) // Used for callback and re-auth - // Handle special URLs if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) return } - if req.URL.Path == t.redirURLPath { t.handleCallback(rw, req, redirectURL) return } - // Check authentication status + // --- Authentication & Refresh Logic --- authenticated, needsRefresh, expired := t.isUserAuthenticated(session) if expired { + t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth") + // handleExpiredToken clears the session and initiates auth t.handleExpiredToken(rw, req, session, redirectURL) return } - if !authenticated { - // Original logic: Always initiate authentication if not authenticated - t.logger.Debug("User not authenticated, initiating OIDC flow") - t.defaultInitiateAuthentication(rw, req, session, redirectURL) + // If authenticated and token doesn't need proactive refresh, proceed directly + if authenticated && !needsRefresh { + t.logger.Debug("User authenticated and token valid, proceeding to process authorized request") + t.processAuthorizedRequest(rw, req, session, redirectURL) + return + } + + // --- Attempt Refresh if Needed or Possible --- + // Conditions to attempt refresh: + // 1. Token needs proactive refresh (authenticated=true, needsRefresh=true) + // 2. Token is invalid/expired but a refresh token exists (authenticated=false, needsRefresh=true) + refreshTokenPresent := session.GetRefreshToken() != "" + shouldAttemptRefresh := needsRefresh && refreshTokenPresent + + if shouldAttemptRefresh { + if needsRefresh && authenticated { + t.logger.Debug("Session token needs proactive refresh, attempting refresh") + } else if needsRefresh && !authenticated { + t.logger.Debug("Access token invalid/expired, but refresh token found. Attempting refresh.") + } + + refreshed := t.refreshToken(rw, req, session) + if refreshed { + // Refresh succeeded, proceed to authorization checks + t.logger.Debug("Token refresh successful, proceeding to process authorized request") + t.processAuthorizedRequest(rw, req, session, redirectURL) + return + } + + // Refresh failed + t.logger.Infof("Token refresh failed (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent) + // Handle refresh failure (401 for API, re-auth for browser) + acceptHeader := req.Header.Get("Accept") + if strings.Contains(acceptHeader, "application/json") { + t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure") + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"}) + } else { + t.logger.Debug("Client does not prefer JSON, handling refresh failure by initiating re-auth") + // Use defaultInitiateAuthentication which clears the session properly + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + } return // Stop processing } - if needsRefresh { - refreshed := t.refreshToken(rw, req, session) - if !refreshed { - t.logger.Infof("Token refresh failed") // Changed from Warn to Infof - // Check if the client prefers JSON (likely an API call) - acceptHeader := req.Header.Get("Accept") - if strings.Contains(acceptHeader, "application/json") { - t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure") - rw.Header().Set("Content-Type", "application/json") - rw.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"}) - } else { - // Client likely a browser, initiate full re-authentication - t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)") - t.handleExpiredToken(rw, req, session, redirectURL) - } - return // Stop processing - } - } + // --- Initiate Full Authentication --- + // If we reach here, it means: + // - User is not authenticated (!authenticated) + // - AND EITHER token doesn't need refresh (!needsRefresh, e.g., first visit) + // - OR refresh token is missing (!refreshTokenPresent) + // - OR refresh was attempted but failed (handled above) + t.logger.Debugf("Initiating full OIDC authentication flow (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent) + t.defaultInitiateAuthentication(rw, req, session, redirectURL) +} - // Process authenticated request +// processAuthorizedRequest handles the final steps for an authenticated and authorized request. +// It performs domain/role/group checks, sets headers, and forwards the request. +func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { email := session.GetEmail() if email == "" { - t.logger.Debug("No email found in session") + t.logger.Error("CRITICAL: No email found in session during final processing, initiating re-auth") + // This case should ideally not happen if checks are done correctly before calling this, + // but as a safeguard, initiate re-authentication. t.defaultInitiateAuthentication(rw, req, session, redirectURL) return } @@ -721,6 +764,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken()) if err != nil { t.logger.Errorf("Failed to extract groups and roles: %v", err) + // Continue without group/role headers if extraction fails } else { if len(groups) > 0 { req.Header.Set("X-User-Groups", strings.Join(groups, ",")) @@ -779,6 +823,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } // Process the request + t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email) t.next.ServeHTTP(rw, req) } @@ -794,19 +839,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // - session: The user's session data containing the expired token information. // - redirectURL: The callback URL to be used in the new authentication flow. func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { - // Clear authentication data but preserve CSRF state + t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.") + // Clear authentication data but preserve CSRF state if possible (though Clear might remove it) session.SetAuthenticated(false) session.SetAccessToken("") session.SetRefreshToken("") session.SetEmail("") - // Save the cleared session state + // Save the cleared session state (this sends expired cookies) + // Pass rw to ensure expiring cookies are sent if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save cleared session: %v", err) - http.Error(rw, "Internal Server Error", http.StatusInternalServerError) - return + t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err) + // Still attempt to initiate authentication, but log the error } + // Initiate a new authentication flow t.defaultInitiateAuthentication(rw, req, session, redirectURL) } @@ -834,8 +881,8 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) { session, err := t.sessionManager.GetSession(req) if err != nil { - t.logger.Errorf("Session error: %v", err) - http.Error(rw, "Session error", http.StatusInternalServerError) + t.logger.Errorf("Session error during callback: %v", err) + http.Error(rw, "Session error during callback", http.StatusInternalServerError) return } @@ -847,7 +894,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, if errorDescription == "" { errorDescription = req.URL.Query().Get("error") // Use error code if description is empty } - t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription) + t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription) t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest) return } @@ -862,13 +909,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, csrfToken := session.GetCSRF() if csrfToken == "" { - t.logger.Error("CSRF token missing in session") - t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest) + t.logger.Error("CSRF token missing in session during callback") + t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest) return } if state != csrfToken { - t.logger.Error("State parameter does not match CSRF token in session") + t.logger.Error("State parameter does not match CSRF token in session during callback") t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest) return } @@ -886,22 +933,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier) if err != nil { - t.logger.Errorf("Failed to exchange code for token: %v", err) + t.logger.Errorf("Failed to exchange code for token during callback: %v", err) t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError) return } // Verify tokens and claims - // Use the exported VerifyToken method now that handleCallback is in main.go if err := t.VerifyToken(tokenResponse.IDToken); err != nil { - t.logger.Errorf("Failed to verify id_token: %v", err) + t.logger.Errorf("Failed to verify id_token during callback: %v", err) t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError) return } claims, err := t.extractClaimsFunc(tokenResponse.IDToken) if err != nil { - t.logger.Errorf("Failed to extract claims: %v", err) + t.logger.Errorf("Failed to extract claims during callback: %v", err) t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError) return } @@ -909,51 +955,68 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, // Verify nonce to prevent replay attacks nonceClaim, ok := claims["nonce"].(string) if !ok || nonceClaim == "" { - t.logger.Error("Nonce claim missing in id_token") + t.logger.Error("Nonce claim missing in id_token during callback") t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError) return } sessionNonce := session.GetNonce() if sessionNonce == "" { - t.logger.Error("Nonce not found in session") + t.logger.Error("Nonce not found in session during callback") t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError) return } if nonceClaim != sessionNonce { - t.logger.Error("Nonce claim does not match session nonce") + t.logger.Error("Nonce claim does not match session nonce during callback") t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError) return } // Validate user's email domain - // Use the unexported isAllowedDomain method now that handleCallback is in main.go email, _ := claims["email"].(string) - if email == "" || !t.isAllowedDomain(email) { - t.logger.Errorf("Invalid or disallowed email: %s", email) - t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) + if email == "" { + t.logger.Errorf("Email claim missing or empty in token during callback") + t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError) + return + } + if !t.isAllowedDomain(email) { + t.logger.Errorf("Disallowed email domain during callback: %s", email) + t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden) return } // Update session with authentication data - session.SetAuthenticated(true) + // Regenerate session ID upon successful authentication + if err := session.SetAuthenticated(true); err != nil { + t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err) + http.Error(rw, "Failed to update session", http.StatusInternalServerError) + return + } session.SetEmail(email) session.SetAccessToken(tokenResponse.IDToken) session.SetRefreshToken(tokenResponse.RefreshToken) - if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session: %v", err) - http.Error(rw, "Failed to save session", http.StatusInternalServerError) - return - } + // Clear CSRF, Nonce, CodeVerifier after use + session.SetCSRF("") + session.SetNonce("") + session.SetCodeVerifier("") - // Redirect to original path or root + // Retrieve original path *before* saving, as save might clear it if Clear was called concurrently redirectPath := "/" if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath { redirectPath = incomingPath } + session.SetIncomingPath("") // Clear incoming path after retrieving it + if err := session.Save(req, rw); err != nil { + t.logger.Errorf("Failed to save session after callback: %v", err) + http.Error(rw, "Failed to save session after callback", http.StatusInternalServerError) + return + } + + // Redirect to original path or root + t.logger.Debugf("Callback successful, redirecting to %s", redirectPath) http.Redirect(rw, req, redirectPath, http.StatusFound) } @@ -972,7 +1035,7 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool { return true } } - t.logger.Debugf("URL is not excluded - got %s", currentRequest) + // t.logger.Debugf("URL is not excluded - got %s", currentRequest) // Too verbose for every request return false } @@ -1024,32 +1087,58 @@ func (t *TraefikOidc) determineHost(req *http.Request) string { // - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims). func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) { if !session.GetAuthenticated() { - t.logger.Debug("User is not authenticated according to session") - return false, false, false + t.logger.Debug("User is not authenticated according to session flag") + // Check if there's still a refresh token - if so, refresh might be possible + if session.GetRefreshToken() != "" { + t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated, NeedsRefresh=true (to attempt recovery), Expired=false + } + return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth) } accessToken := session.GetAccessToken() if accessToken == "" { - t.logger.Debug("No access token found in session") - return false, false, true // Session is invalid, consider it expired + t.logger.Debug("Authenticated flag set, but no access token found in session") + // If authenticated flag is true but token is missing, treat as expired/invalid session state + // Check for refresh token before declaring fully expired + if session.GetRefreshToken() != "" { + t.logger.Debug("Authenticated flag set, access token missing, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated (no access token), NeedsRefresh=true, Expired=false + } + return false, false, true // No access or refresh token, treat as expired } // Verify the token structure and signature first jwt, err := parseJWT(accessToken) if err != nil { t.logger.Errorf("Failed to parse JWT during auth check: %v", err) - return false, false, true // Invalid format, treat as expired/invalid + // Check for refresh token before declaring fully expired + if session.GetRefreshToken() != "" { + t.logger.Debug("Access token parsing failed, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false + } + return false, false, true // Invalid format, no refresh token, treat as expired/invalid } if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil { // Check if the error is specifically about expiration if strings.Contains(err.Error(), "token has expired") { - t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh") + t.logger.Debugf("Access token signature/claims valid but token expired, needs refresh") // Token is expired but otherwise valid, signal for refresh - return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible) + // Return authenticated=false because the current token is unusable + // NeedsRefresh is true only if a refresh token exists + if session.GetRefreshToken() != "" { + return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it) + } + return false, false, true // Expired access token, no refresh token, treat as expired } // Other verification error (signature, issuer, audience etc.) - t.logger.Errorf("Token verification failed (non-expiration): %v", err) - return false, false, true // Token is invalid for other reasons + t.logger.Errorf("Access token verification failed (non-expiration): %v", err) + // Check for refresh token before declaring fully expired + if session.GetRefreshToken() != "" { + t.logger.Debug("Access token verification failed, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false + } + return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session } // Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early @@ -1057,8 +1146,13 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo expClaim, ok := claims["exp"].(float64) if !ok { - t.logger.Error("Failed to get expiration time from claims") - return false, false, true + t.logger.Error("Failed to get expiration time ('exp' claim) from verified token") + // Check for refresh token before declaring fully expired + if session.GetRefreshToken() != "" { + t.logger.Debug("Access token missing 'exp' claim, but refresh token exists. Signaling need for refresh.") + return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false + } + return false, false, true // Treat as invalid if 'exp' is missing and no refresh token } expTime := int64(expClaim) @@ -1071,12 +1165,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) { // Recalculate remaining seconds for logging clarity if needed, using the configured duration remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds()) - t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod) - return true, true, false // Needs proactive refresh + t.logger.Debugf("Access token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod) + // Token is still valid, but we should refresh it soon + // NeedsRefresh is true only if a refresh token exists + if session.GetRefreshToken() != "" { + return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false + } + // If no refresh token, we can't proactively refresh, treat as normal valid token for now + t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.") + return true, false, false } // Token is valid, not expired, and not nearing expiration - return true, false, false + return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false } // defaultInitiateAuthentication handles the process of starting an OIDC authentication flow. @@ -1092,10 +1193,12 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo // - session: The user's SessionData object (potentially new or cleared). // - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance. func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { + t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI()) // Generate CSRF token and nonce csrfToken := uuid.NewString() nonce, err := generateNonce() if err != nil { + t.logger.Errorf("Failed to generate nonce: %v", err) http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError) return } @@ -1106,37 +1209,41 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req var err error codeVerifier, err = generateCodeVerifier() if err != nil { + t.logger.Errorf("Failed to generate code verifier: %v", err) http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError) return } - - // Derive code challenge from verifier codeChallenge = deriveCodeChallenge(codeVerifier) + t.logger.Debugf("PKCE enabled, generated code challenge") } // Clear any existing session data to avoid stale state causing redirect loops - session.Clear(req, rw) + // Pass the response writer to ensure expiring cookies are sent + if err := session.Clear(req, rw); err != nil { + // Log the error but continue, as clearing is best-effort before re-auth + t.logger.Errorf("Error clearing session before initiating authentication: %v", err) + } // Set new session values session.SetCSRF(csrfToken) session.SetNonce(nonce) - - // Only set code verifier if PKCE is enabled if t.enablePKCE { session.SetCodeVerifier(codeVerifier) } - + // Store the original path the user was trying to access session.SetIncomingPath(req.URL.RequestURI()) + t.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI()) - // Save the session + // Save the session (to store CSRF, Nonce, etc.) if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save session: %v", err) + t.logger.Errorf("Failed to save session before redirecting to provider: %v", err) http.Error(rw, "Failed to save session", http.StatusInternalServerError) return } // Build and redirect to authentication URL authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge) + t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL) http.Redirect(rw, req, authURL, http.StatusFound) } @@ -1185,6 +1292,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri params.Set("scope", strings.Join(t.scopes, " ")) } + // Use buildURLWithParams which handles potential relative authURL from metadata return t.buildURLWithParams(t.authURL, params) } @@ -1201,17 +1309,30 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string { // Ensure URL is absolute if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - // Extract issuer base URL - issuerURL, err := url.Parse(t.issuerURL) + // Attempt to resolve relative URL against issuer URL + issuerURLParsed, err := url.Parse(t.issuerURL) if err == nil { - return fmt.Sprintf("%s://%s%s?%s", - issuerURL.Scheme, - issuerURL.Host, - baseURL, - params.Encode()) + baseURLParsed, err := url.Parse(baseURL) + if err == nil { + resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed) + resolvedURL.RawQuery = params.Encode() + return resolvedURL.String() + } } + // Fallback if parsing fails - append params to potentially relative path + t.logger.Errorf("Could not parse issuerURL or baseURL to resolve relative URL. BaseURL: %s, IssuerURL: %s", baseURL, t.issuerURL) + return baseURL + "?" + params.Encode() } - return baseURL + "?" + params.Encode() + + // If baseURL is already absolute + u, err := url.Parse(baseURL) + if err != nil { + t.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err) + // Fallback: append params directly + return baseURL + "?" + params.Encode() + } + u.RawQuery = params.Encode() + return u.String() } // startTokenCleanup starts background goroutines for periodically cleaning up @@ -1246,6 +1367,7 @@ func (t *TraefikOidc) RevokeToken(token string) { expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration // Use Set with a duration. Value 'true' is arbitrary, we only care about existence. t.tokenBlacklist.Set(token, true, time.Until(expiry)) + t.logger.Debugf("Locally revoked token (added to blacklist)") } // RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider @@ -1260,7 +1382,10 @@ func (t *TraefikOidc) RevokeToken(token string) { // - nil if the revocation request is successful (provider returns 200 OK). // - An error if the request fails or the provider returns a non-OK status. func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { - t.logger.Debugf("Revoking token with provider") + if t.revocationURL == "" { + return fmt.Errorf("token revocation endpoint is not configured or discovered") + } + t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL) data := url.Values{ "token": {token}, @@ -1277,6 +1402,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { // Set headers req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") // Prefer JSON response if available // Send the request resp, err := t.httpClient.Do(req) @@ -1288,18 +1414,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { // Check the response if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("token revocation failed with status %d: %s", resp.StatusCode, string(body)) + // Log the failure details + t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body)) + return fmt.Errorf("token revocation failed with status %d", resp.StatusCode) } - t.logger.Debugf("Token successfully revoked") + t.logger.Debugf("Token successfully revoked with provider") return nil } // refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens. // It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session. // It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method, -// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens, -// and saves the session. +// verifies the newly obtained ID token using verifyToken, performs a concurrency check, +// updates the session with the new tokens if the check passes, and saves the session. // // Parameters: // - rw: The HTTP response writer (needed for saving the updated session). @@ -1308,45 +1436,85 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { // // Returns: // - true if the token refresh was successful and the session was updated. -// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed. +// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, +// a concurrency conflict was detected, or saving the session failed. func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool { // Lock the mutex specific to this session instance before attempting refresh session.refreshMutex.Lock() defer session.refreshMutex.Unlock() t.logger.Debug("Attempting to refresh token (mutex acquired)") - refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock - if refreshToken == "" { - t.logger.Debug("No refresh token found in session (inside lock)") + initialRefreshToken := session.GetRefreshToken() // Get token *after* acquiring lock + if initialRefreshToken == "" { + t.logger.Errorf("refreshToken failed: No refresh token found in session (after acquiring lock)") return false } - newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) + // Store the initial token for later comparison + t.logger.Debugf("Attempting refresh with token starting with %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)]) + + newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken) if err != nil { - // Log the error, potentially clear the invalid refresh token? - t.logger.Errorf("Failed to refresh token using refresh token: %v", err) - // Consider clearing the refresh token from the session here if the error indicates it's invalid - // session.SetRefreshToken("") // Example: Clear potentially invalid token - // session.Save(req, rw) // Need to handle potential save error + // Log the error more explicitly before returning false + truncatedToken := initialRefreshToken[:min(len(initialRefreshToken), 10)] // Log first 10 chars + t.logger.Errorf("refreshToken failed: Error from tokenExchanger.GetNewTokenWithRefreshToken for token starting with %s...: %v", truncatedToken, err) + // No need to clear token here, as the session might be cleared by another request anyway return false } - // Verify the new access token + // Verify the new access token (ID token) if err := t.verifyToken(newToken.IDToken); err != nil { - t.logger.Errorf("Failed to verify new access token: %v", err) + truncatedNewToken := newToken.IDToken[:min(len(newToken.IDToken), 10)] // Log first 10 chars + t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err) return false } - // Update session with new tokens + // --- Concurrency Check --- + // Before saving the new token, check if the session state (specifically the refresh token) + // has been modified concurrently (e.g., by a logout or another auth initiation). + currentRefreshToken := session.GetRefreshToken() // Get token again *after* the potentially long exchange + if initialRefreshToken != currentRefreshToken { + // Use Infof as Warnf doesn't exist + t.logger.Infof("refreshToken aborted: Session refresh token changed concurrently during refresh attempt. Initial token prefix: %s..., Current token prefix: %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)], currentRefreshToken[:min(len(currentRefreshToken), 10)]) + // Do not save the new tokens, as the session state is likely invalid/cleared. + return false // Indicate refresh failure due to concurrency conflict + } + // --- End Concurrency Check --- + + // Update session with new tokens ONLY if the concurrency check passed + t.logger.Debugf("Concurrency check passed. Updating session with new tokens.") + + // Extract email from the new token and update session + claims, err := t.extractClaimsFunc(newToken.IDToken) + if err != nil { + t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err) + return false // Cannot proceed without claims + } + email, _ := claims["email"].(string) + if email == "" { + t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token") + return false // Cannot proceed without email + } + session.SetEmail(email) // Update email in session + session.SetAccessToken(newToken.IDToken) - session.SetRefreshToken(newToken.RefreshToken) + // Ensure the new refresh token is actually set, even if it's the same as the old one + // Also handle cases where the provider might not return a new refresh token + if newToken.RefreshToken != "" { + session.SetRefreshToken(newToken.RefreshToken) + } else { + // If no new refresh token is returned, keep the existing one + session.SetRefreshToken(initialRefreshToken) + t.logger.Debugf("Provider did not return a new refresh token, keeping the existing one.") + } // Save the session if err := session.Save(req, rw); err != nil { - t.logger.Errorf("Failed to save refreshed session: %v", err) + t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh and concurrency check: %v", err) return false } + t.logger.Debugf("Token refresh successful and session saved.") return true } @@ -1367,6 +1535,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool { parts := strings.Split(email, "@") if len(parts) != 2 { + t.logger.Errorf("Invalid email format encountered: %s", email) return false // Invalid email format } @@ -1400,12 +1569,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, if groupsClaim, exists := claims["groups"]; exists { groupsSlice, ok := groupsClaim.([]interface{}) if !ok { + // Strictly expect an array return nil, nil, fmt.Errorf("groups claim is not an array") - } - for _, group := range groupsSlice { - if groupStr, ok := group.(string); ok { - t.logger.Debugf("Found group: %s", groupStr) - groups = append(groups, groupStr) + } else { + for _, group := range groupsSlice { + if groupStr, ok := group.(string); ok { + t.logger.Debugf("Found group: %s", groupStr) + groups = append(groups, groupStr) + } else { + t.logger.Errorf("Non-string value found in groups claim array: %v", group) + } } } } @@ -1414,12 +1587,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, if rolesClaim, exists := claims["roles"]; exists { rolesSlice, ok := rolesClaim.([]interface{}) if !ok { + // Strictly expect an array return nil, nil, fmt.Errorf("roles claim is not an array") - } - for _, role := range rolesSlice { - if roleStr, ok := role.(string); ok { - t.logger.Debugf("Found role: %s", roleStr) - roles = append(roles, roleStr) + } else { + for _, role := range rolesSlice { + if roleStr, ok := role.(string); ok { + t.logger.Debugf("Found role: %s", roleStr) + roles = append(roles, roleStr) + } else { + t.logger.Errorf("Non-string value found in roles claim array: %v", role) + } } } } @@ -1493,8 +1670,8 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques rw.WriteHeader(code) // Use a simple error structure json.NewEncoder(rw).Encode(map[string]interface{}{ - "error": http.StatusText(code), - "error_description": message, + "error": http.StatusText(code), // Use standard text for the code + "error_description": message, // Provide specific detail here "status_code": code, }) return @@ -1504,17 +1681,9 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques t.logger.Debugf("Sending HTML error response (code %d): %s", code, message) // Determine the return URL (mostly relevant for HTML) - returnURL := "/" // Default to root - session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL - if err == nil { - incomingPath := session.GetIncomingPath() - // Use incoming path if it's valid and not one of the special OIDC paths - if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath { - returnURL = incomingPath - } - } else { - t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err) - } + returnURL := "/" // Default to root + // No need to get session here, as we are already in an error path + // where session might be invalid or unavailable. // Basic HTML structure for the error page htmlBody := fmt.Sprintf(` @@ -1537,7 +1706,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques

Return to application

-`, message, returnURL) +`, message, returnURL) // Use default returnURL rw.Header().Set("Content-Type", "text/html; charset=utf-8") rw.WriteHeader(code) diff --git a/main_test.go b/main_test.go index b86c325..97acce3 100644 --- a/main_test.go +++ b/main_test.go @@ -385,9 +385,52 @@ func TestServeHTTP(t *testing.T) { expectedBody: "OK", }, { - name: "Unauthenticated request to protected URL", - requestPath: "/protected", - expectedStatus: http.StatusFound, // Expect redirect to OIDC + name: "Unauthenticated request (no refresh token) to protected URL", + requestPath: "/protected", + setupSession: func(session *SessionData) { + // Ensure no tokens are set + session.SetAuthenticated(false) + session.SetAccessToken("") + session.SetRefreshToken("") + }, + expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token + }, + { + name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(false) // Not authenticated + session.SetAccessToken("") // No access token + session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + if refreshToken != "valid-refresh-token-for-unauth-test" { + return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken) + } + // Simulate successful refresh + newToken := createNewValidToken() // Use helper from TestServeHTTP + return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil + } + }, + expectedStatus: http.StatusOK, // Expect OK after successful refresh + expectedBody: "OK", + }, + { + name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails", + requestPath: "/protected", + setupSession: func(session *SessionData) { + session.SetAuthenticated(false) // Not authenticated + session.SetAccessToken("") // No access token + session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token + }, + mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { + return func(refreshToken string) (*TokenResponse, error) { + // Simulate failed refresh + return nil, fmt.Errorf("mock error: refresh token invalid") + } + }, + expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh }, { name: "Authenticated request to protected URL (Valid Token)", @@ -407,11 +450,15 @@ func TestServeHTTP(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: "OK", }, + // This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist { name: "Authenticated request with expired token and successful refresh", requestPath: "/protected", setupSession: func(session *SessionData) { - session.SetAuthenticated(true) // Still marked authenticated initially + // NOTE: isUserAuthenticated now returns authenticated=false if access token is expired, + // even if session.SetAuthenticated(true) was called. + // We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt. + session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token session.SetEmail("user@example.com") session.SetAccessToken(createExpiredToken()) // Set expired token session.SetRefreshToken("valid-refresh-token") // Set valid refresh token @@ -445,16 +492,19 @@ func TestServeHTTP(t *testing.T) { t.Fatalf("Failed to get session after request: %v", err) } // Assert new tokens are in the session - // Direct comparison with createNewValidToken() is flawed as it generates a new token each time. - // Instead, check if the token was updated (not empty) and verify the refresh token. - if session.GetAccessToken() == "" { - t.Errorf("Expected access token to be updated in session, but it was empty") + if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() { + t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one") } if session.GetRefreshToken() != "new-refresh-token" { t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken()) } + // Also check authenticated flag is now true + if !session.GetAuthenticated() { + t.Errorf("Expected session to be marked authenticated after successful refresh") + } }, }, + // This test case remains valid as the logic should still return 401 for API clients on refresh failure { name: "Logout URL", requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup @@ -477,10 +527,10 @@ func TestServeHTTP(t *testing.T) { name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)", requestPath: "/protected", setupSession: func(session *SessionData) { - session.SetAuthenticated(true) + session.SetAuthenticated(true) // Set flag initially session.SetEmail("user@example.com") - session.SetAccessToken(createExpiredToken()) - session.SetRefreshToken("valid-refresh-token") + session.SetAccessToken(createExpiredToken()) // Expired access token + session.SetRefreshToken("valid-refresh-token") // Valid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { @@ -491,17 +541,18 @@ func TestServeHTTP(t *testing.T) { requestHeaders: map[string]string{ "Accept": "application/json", }, - expectedStatus: http.StatusUnauthorized, // Expect 401 for API client + expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`, }, + // This test case remains valid as the logic should still redirect browser clients on refresh failure { name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)", requestPath: "/protected", setupSession: func(session *SessionData) { - session.SetAuthenticated(true) + session.SetAuthenticated(true) // Set flag initially session.SetEmail("user@example.com") - session.SetAccessToken(createExpiredToken()) - session.SetRefreshToken("valid-refresh-token") + session.SetAccessToken(createExpiredToken()) // Expired access token + session.SetRefreshToken("valid-refresh-token") // Valid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { @@ -512,8 +563,9 @@ func TestServeHTTP(t *testing.T) { requestHeaders: map[string]string{ "Accept": "text/html", // Browser client }, - expectedStatus: http.StatusFound, // Expect redirect for browser client + expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt }, + // This test case remains valid as proactive refresh should still be attempted { name: "Authenticated request with token nearing expiry (needs refresh)", requestPath: "/protected", @@ -529,7 +581,7 @@ func TestServeHTTP(t *testing.T) { session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(nearExpiryToken) - session.SetRefreshToken("valid-refresh-token-for-near-expiry") + session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { @@ -544,6 +596,7 @@ func TestServeHTTP(t *testing.T) { expectedStatus: http.StatusOK, // Expect success after proactive refresh expectedBody: "OK", }, + // This test case remains valid as no refresh should be attempted { name: "Authenticated request with token valid (outside grace period)", requestPath: "/protected", @@ -1531,6 +1584,7 @@ func TestRevokeToken(t *testing.T) { tOidc := &TraefikOidc{ tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), + logger: NewLogger("info"), // Initialize the logger } // Cache the token