Compare commits

...

3 Commits

Author SHA1 Message Date
lukaszraczylo 8a6e37f7fc Create LICENSE 2025-04-10 01:39:57 +01:00
lukaszraczylo bd7eaf6dff Bugfix: Refresh token not obtained when access token is expired. 2025-04-05 18:28:12 +01:00
lukaszraczylo 3df19e6d90 Update README.md 2025-04-05 14:56:28 +01:00
4 changed files with 444 additions and 199 deletions
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Lukasz Raczylo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+9 -8
View File
@@ -69,14 +69,15 @@ The middleware supports the following configuration options:
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
## Usage Examples
+343 -174
View File
@@ -17,6 +17,14 @@ import (
"golang.org/x/time/rate"
)
// min returns the smaller of x or y.
func min(x, y int) int {
if x > y {
return y
}
return x
}
// createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication.
// It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection),
// connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout.
@@ -442,6 +450,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to get provider metadata: %v", err)
// Consider retrying or handling this more gracefully
return
}
@@ -457,7 +466,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
return
}
t.logger.Error("Received nil metadata")
t.logger.Error("Received nil metadata during initialization")
// Consider what should happen if metadata is nil after GetMetadata returns no error
}
// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.)
@@ -497,6 +507,8 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
if metadata != nil {
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
} else {
t.logger.Error("Received nil metadata during refresh")
}
}
}
@@ -544,7 +556,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
if delay > maxDelay {
delay = maxDelay
}
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
l.Debugf("Failed to fetch provider metadata (attempt %d/%d), retrying in %s. Error: %v", attempt+1, maxRetries, delay, err)
time.Sleep(delay)
}
@@ -568,64 +580,56 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("received nil response from provider")
return nil, fmt.Errorf("received nil response from provider at %s", wellKnownURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch provider metadata: status code %d", resp.StatusCode)
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to fetch provider metadata from %s: status code %d, body: %s", wellKnownURL, resp.StatusCode, string(bodyBytes))
}
var metadata ProviderMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
// Attempt to read body for better error context if decoding fails
// Note: resp.Body might be partially read by Decode, so read remaining
bodyBytes, readErr := io.ReadAll(io.MultiReader(json.NewDecoder(resp.Body).Buffered(), resp.Body))
if readErr != nil {
bodyBytes = []byte(fmt.Sprintf("(failed to read response body: %v)", readErr))
}
return nil, fmt.Errorf("failed to decode provider metadata from %s: %w. Response body: %s", wellKnownURL, err, string(bodyBytes))
}
return &metadata, nil
}
// ServeHTTP is the main entry point for incoming requests to the middleware.
// It orchestrates the OIDC authentication flow:
// 1. Waits for initial OIDC metadata discovery to complete (with timeout).
// 2. Checks if the request path is excluded from authentication.
// 3. Checks if the request is for Server-Sent Events and bypasses if so.
// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing.
// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout).
// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry).
// 7. If the token is expired, handles it (initiates re-auth).
// 8. If the user is not authenticated, initiates authentication.
// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure
// by returning 401 for API clients or initiating re-auth for browsers.
// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups).
// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*)
// and adds security headers (X-Frame-Options, etc.) to the response.
// 12. Forwards the request to the next handler in the chain.
// It orchestrates the OIDC authentication flow.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// --- Initialization Check ---
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
if t.issuerURL == "" { // Check if initialization actually succeeded
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
return
}
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
t.logger.Debug("Request cancelled while waiting for OIDC initialization")
http.Error(rw, "Request cancelled", http.StatusRequestTimeout) // 408 might be more appropriate
return
case <-time.After(30 * time.Second):
case <-time.After(30 * time.Second): // Timeout for initialization
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
// --- Excluded Paths & SSE Check ---
if t.determineExcludedURL(req.URL.Path) {
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
t.next.ServeHTTP(rw, req)
return
}
// Check if the request expects Server-Sent Events
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
@@ -633,80 +637,119 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Get session
// --- Session Retrieval ---
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
// Log the specific session error
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
// Attempt to get a new session to store CSRF etc.
session, _ = t.sessionManager.GetSession(req) // Ignore error here, proceed with new session
if session != nil {
// Pass rw to ensure expiring cookies are sent if possible
if clearErr := session.Clear(req, rw); clearErr != nil {
t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr)
}
} else {
// If even getting a new session fails, something is very wrong
t.logger.Error("Critical session error: Failed to get even a new session.")
http.Error(rw, "Critical session error", http.StatusInternalServerError)
return
}
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Build redirect URL
// --- URL Handling (Callback, Logout) ---
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
redirectURL := buildFullURL(scheme, host, t.redirURLPath) // Used for callback and re-auth
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
}
// Check authentication status
// --- Authentication & Refresh Logic ---
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
// handleExpiredToken clears the session and initiates auth
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
if !authenticated {
// Original logic: Always initiate authentication if not authenticated
t.logger.Debug("User not authenticated, initiating OIDC flow")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// If authenticated and token doesn't need proactive refresh, proceed directly
if authenticated && !needsRefresh {
t.logger.Debug("User authenticated and token valid, proceeding to process authorized request")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
// --- Attempt Refresh if Needed or Possible ---
// Conditions to attempt refresh:
// 1. Token needs proactive refresh (authenticated=true, needsRefresh=true)
// 2. Token is invalid/expired but a refresh token exists (authenticated=false, needsRefresh=true)
refreshTokenPresent := session.GetRefreshToken() != ""
shouldAttemptRefresh := needsRefresh && refreshTokenPresent
if shouldAttemptRefresh {
if needsRefresh && authenticated {
t.logger.Debug("Session token needs proactive refresh, attempting refresh")
} else if needsRefresh && !authenticated {
t.logger.Debug("Access token invalid/expired, but refresh token found. Attempting refresh.")
}
refreshed := t.refreshToken(rw, req, session)
if refreshed {
// Refresh succeeded, proceed to authorization checks
t.logger.Debug("Token refresh successful, proceeding to process authorized request")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
// Refresh failed
t.logger.Infof("Token refresh failed (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
// Handle refresh failure (401 for API, re-auth for browser)
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
} else {
t.logger.Debug("Client does not prefer JSON, handling refresh failure by initiating re-auth")
// Use defaultInitiateAuthentication which clears the session properly
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
return // Stop processing
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.logger.Infof("Token refresh failed") // Changed from Warn to Infof
// Check if the client prefers JSON (likely an API call)
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
} else {
// Client likely a browser, initiate full re-authentication
t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)")
t.handleExpiredToken(rw, req, session, redirectURL)
}
return // Stop processing
}
}
// --- Initiate Full Authentication ---
// If we reach here, it means:
// - User is not authenticated (!authenticated)
// - AND EITHER token doesn't need refresh (!needsRefresh, e.g., first visit)
// - OR refresh token is missing (!refreshTokenPresent)
// - OR refresh was attempted but failed (handled above)
t.logger.Debugf("Initiating full OIDC authentication flow (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// Process authenticated request
// processAuthorizedRequest handles the final steps for an authenticated and authorized request.
// It performs domain/role/group checks, sets headers, and forwards the request.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Debug("No email found in session")
t.logger.Error("CRITICAL: No email found in session during final processing, initiating re-auth")
// This case should ideally not happen if checks are done correctly before calling this,
// but as a safeguard, initiate re-authentication.
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -721,6 +764,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Continue without group/role headers if extraction fails
} else {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
@@ -779,6 +823,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// Process the request
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
t.next.ServeHTTP(rw, req)
}
@@ -794,19 +839,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - session: The user's session data containing the expired token information.
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
// Clear authentication data but preserve CSRF state if possible (though Clear might remove it)
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
// Save the cleared session state (this sends expired cookies)
// Pass rw to ensure expiring cookies are sent
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
// Still attempt to initiate authentication, but log the error
}
// Initiate a new authentication flow
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -834,8 +881,8 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
t.logger.Errorf("Session error during callback: %v", err)
http.Error(rw, "Session error during callback", http.StatusInternalServerError)
return
}
@@ -847,7 +894,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error") // Use error code if description is empty
}
t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
@@ -862,13 +909,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest)
t.logger.Error("CSRF token missing in session during callback")
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
t.logger.Error("State parameter does not match CSRF token in session during callback")
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
@@ -886,22 +933,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
// Verify tokens and claims
// Use the exported VerifyToken method now that handleCallback is in main.go
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.logger.Errorf("Failed to extract claims during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
@@ -909,51 +955,68 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
t.logger.Error("Nonce claim missing in id_token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
t.logger.Error("Nonce not found in session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
t.logger.Error("Nonce claim does not match session nonce during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
// Validate user's email domain
// Use the unexported isAllowedDomain method now that handleCallback is in main.go
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
if email == "" {
t.logger.Errorf("Email claim missing or empty in token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !t.isAllowedDomain(email) {
t.logger.Errorf("Disallowed email domain during callback: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
// Update session with authentication data
session.SetAuthenticated(true)
// Regenerate session ID upon successful authentication
if err := session.SetAuthenticated(true); err != nil {
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
http.Error(rw, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Clear CSRF, Nonce, CodeVerifier after use
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
// Redirect to original path or root
// Retrieve original path *before* saving, as save might clear it if Clear was called concurrently
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("") // Clear incoming path after retrieving it
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after callback: %v", err)
http.Error(rw, "Failed to save session after callback", http.StatusInternalServerError)
return
}
// Redirect to original path or root
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -972,7 +1035,7 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
return true
}
}
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
// t.logger.Debugf("URL is not excluded - got %s", currentRequest) // Too verbose for every request
return false
}
@@ -1024,32 +1087,58 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims).
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
t.logger.Debug("User is not authenticated according to session flag")
// Check if there's still a refresh token - if so, refresh might be possible
if session.GetRefreshToken() != "" {
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated, NeedsRefresh=true (to attempt recovery), Expired=false
}
return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth)
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
return false, false, true // Session is invalid, consider it expired
t.logger.Debug("Authenticated flag set, but no access token found in session")
// If authenticated flag is true but token is missing, treat as expired/invalid session state
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Authenticated flag set, access token missing, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (no access token), NeedsRefresh=true, Expired=false
}
return false, false, true // No access or refresh token, treat as expired
}
// Verify the token structure and signature first
jwt, err := parseJWT(accessToken)
if err != nil {
t.logger.Errorf("Failed to parse JWT during auth check: %v", err)
return false, false, true // Invalid format, treat as expired/invalid
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token parsing failed, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Invalid format, no refresh token, treat as expired/invalid
}
if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil {
// Check if the error is specifically about expiration
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh")
t.logger.Debugf("Access token signature/claims valid but token expired, needs refresh")
// Token is expired but otherwise valid, signal for refresh
return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible)
// Return authenticated=false because the current token is unusable
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it)
}
return false, false, true // Expired access token, no refresh token, treat as expired
}
// Other verification error (signature, issuer, audience etc.)
t.logger.Errorf("Token verification failed (non-expiration): %v", err)
return false, false, true // Token is invalid for other reasons
t.logger.Errorf("Access token verification failed (non-expiration): %v", err)
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token verification failed, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session
}
// Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early
@@ -1057,8 +1146,13 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Treat as invalid if 'exp' is missing and no refresh token
}
expTime := int64(expClaim)
@@ -1071,12 +1165,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod)
return true, true, false // Needs proactive refresh
t.logger.Debugf("Access token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod)
// Token is still valid, but we should refresh it soon
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false
}
// If no refresh token, we can't proactively refresh, treat as normal valid token for now
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
return true, false, false
}
// Token is valid, not expired, and not nearing expiration
return true, false, false
return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false
}
// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow.
@@ -1092,10 +1193,12 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// - session: The user's SessionData object (potentially new or cleared).
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
// Generate CSRF token and nonce
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
@@ -1106,37 +1209,41 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
var err error
codeVerifier, err = generateCodeVerifier()
if err != nil {
t.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
// Derive code challenge from verifier
codeChallenge = deriveCodeChallenge(codeVerifier)
t.logger.Debugf("PKCE enabled, generated code challenge")
}
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Pass the response writer to ensure expiring cookies are sent
if err := session.Clear(req, rw); err != nil {
// Log the error but continue, as clearing is best-effort before re-auth
t.logger.Errorf("Error clearing session before initiating authentication: %v", err)
}
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
// Only set code verifier if PKCE is enabled
if t.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
// Store the original path the user was trying to access
session.SetIncomingPath(req.URL.RequestURI())
t.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
// Save the session
// Save the session (to store CSRF, Nonce, etc.)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -1185,6 +1292,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
params.Set("scope", strings.Join(t.scopes, " "))
}
// Use buildURLWithParams which handles potential relative authURL from metadata
return t.buildURLWithParams(t.authURL, params)
}
@@ -1201,17 +1309,30 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
// Ensure URL is absolute
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
// Attempt to resolve relative URL against issuer URL
issuerURLParsed, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
baseURL,
params.Encode())
baseURLParsed, err := url.Parse(baseURL)
if err == nil {
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
}
// Fallback if parsing fails - append params to potentially relative path
t.logger.Errorf("Could not parse issuerURL or baseURL to resolve relative URL. BaseURL: %s, IssuerURL: %s", baseURL, t.issuerURL)
return baseURL + "?" + params.Encode()
}
return baseURL + "?" + params.Encode()
// If baseURL is already absolute
u, err := url.Parse(baseURL)
if err != nil {
t.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
// Fallback: append params directly
return baseURL + "?" + params.Encode()
}
u.RawQuery = params.Encode()
return u.String()
}
// startTokenCleanup starts background goroutines for periodically cleaning up
@@ -1246,6 +1367,7 @@ func (t *TraefikOidc) RevokeToken(token string) {
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
t.tokenBlacklist.Set(token, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token (added to blacklist)")
}
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
@@ -1260,7 +1382,10 @@ func (t *TraefikOidc) RevokeToken(token string) {
// - nil if the revocation request is successful (provider returns 200 OK).
// - An error if the request fails or the provider returns a non-OK status.
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
if t.revocationURL == "" {
return fmt.Errorf("token revocation endpoint is not configured or discovered")
}
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
data := url.Values{
"token": {token},
@@ -1277,6 +1402,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
// Set headers
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json") // Prefer JSON response if available
// Send the request
resp, err := t.httpClient.Do(req)
@@ -1288,18 +1414,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
// Check the response
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("token revocation failed with status %d: %s", resp.StatusCode, string(body))
// Log the failure details
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
}
t.logger.Debugf("Token successfully revoked")
t.logger.Debugf("Token successfully revoked with provider")
return nil
}
// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens.
// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session.
// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method,
// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens,
// and saves the session.
// verifies the newly obtained ID token using verifyToken, performs a concurrency check,
// updates the session with the new tokens if the check passes, and saves the session.
//
// Parameters:
// - rw: The HTTP response writer (needed for saving the updated session).
@@ -1308,45 +1436,85 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
//
// Returns:
// - true if the token refresh was successful and the session was updated.
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed.
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification,
// a concurrency conflict was detected, or saving the session failed.
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
// Lock the mutex specific to this session instance before attempting refresh
session.refreshMutex.Lock()
defer session.refreshMutex.Unlock()
t.logger.Debug("Attempting to refresh token (mutex acquired)")
refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
if refreshToken == "" {
t.logger.Debug("No refresh token found in session (inside lock)")
initialRefreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
if initialRefreshToken == "" {
t.logger.Errorf("refreshToken failed: No refresh token found in session (after acquiring lock)")
return false
}
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
// Store the initial token for later comparison
t.logger.Debugf("Attempting refresh with token starting with %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)])
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
if err != nil {
// Log the error, potentially clear the invalid refresh token?
t.logger.Errorf("Failed to refresh token using refresh token: %v", err)
// Consider clearing the refresh token from the session here if the error indicates it's invalid
// session.SetRefreshToken("") // Example: Clear potentially invalid token
// session.Save(req, rw) // Need to handle potential save error
// Log the error more explicitly before returning false
truncatedToken := initialRefreshToken[:min(len(initialRefreshToken), 10)] // Log first 10 chars
t.logger.Errorf("refreshToken failed: Error from tokenExchanger.GetNewTokenWithRefreshToken for token starting with %s...: %v", truncatedToken, err)
// No need to clear token here, as the session might be cleared by another request anyway
return false
}
// Verify the new access token
// Verify the new access token (ID token)
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
truncatedNewToken := newToken.IDToken[:min(len(newToken.IDToken), 10)] // Log first 10 chars
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err)
return false
}
// Update session with new tokens
// --- Concurrency Check ---
// Before saving the new token, check if the session state (specifically the refresh token)
// has been modified concurrently (e.g., by a logout or another auth initiation).
currentRefreshToken := session.GetRefreshToken() // Get token again *after* the potentially long exchange
if initialRefreshToken != currentRefreshToken {
// Use Infof as Warnf doesn't exist
t.logger.Infof("refreshToken aborted: Session refresh token changed concurrently during refresh attempt. Initial token prefix: %s..., Current token prefix: %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)], currentRefreshToken[:min(len(currentRefreshToken), 10)])
// Do not save the new tokens, as the session state is likely invalid/cleared.
return false // Indicate refresh failure due to concurrency conflict
}
// --- End Concurrency Check ---
// Update session with new tokens ONLY if the concurrency check passed
t.logger.Debugf("Concurrency check passed. Updating session with new tokens.")
// Extract email from the new token and update session
claims, err := t.extractClaimsFunc(newToken.IDToken)
if err != nil {
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
return false // Cannot proceed without claims
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
return false // Cannot proceed without email
}
session.SetEmail(email) // Update email in session
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Ensure the new refresh token is actually set, even if it's the same as the old one
// Also handle cases where the provider might not return a new refresh token
if newToken.RefreshToken != "" {
session.SetRefreshToken(newToken.RefreshToken)
} else {
// If no new refresh token is returned, keep the existing one
session.SetRefreshToken(initialRefreshToken)
t.logger.Debugf("Provider did not return a new refresh token, keeping the existing one.")
}
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh and concurrency check: %v", err)
return false
}
t.logger.Debugf("Token refresh successful and session saved.")
return true
}
@@ -1367,6 +1535,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
parts := strings.Split(email, "@")
if len(parts) != 2 {
t.logger.Errorf("Invalid email format encountered: %s", email)
return false // Invalid email format
}
@@ -1400,12 +1569,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
// Strictly expect an array
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
} else {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
} else {
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
}
}
}
}
@@ -1414,12 +1587,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
// Strictly expect an array
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
} else {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
} else {
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
}
}
}
}
@@ -1493,8 +1670,8 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
rw.WriteHeader(code)
// Use a simple error structure
json.NewEncoder(rw).Encode(map[string]interface{}{
"error": http.StatusText(code),
"error_description": message,
"error": http.StatusText(code), // Use standard text for the code
"error_description": message, // Provide specific detail here
"status_code": code,
})
return
@@ -1504,17 +1681,9 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
// Determine the return URL (mostly relevant for HTML)
returnURL := "/" // Default to root
session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL
if err == nil {
incomingPath := session.GetIncomingPath()
// Use incoming path if it's valid and not one of the special OIDC paths
if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath {
returnURL = incomingPath
}
} else {
t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err)
}
returnURL := "/" // Default to root
// No need to get session here, as we are already in an error path
// where session might be invalid or unavailable.
// Basic HTML structure for the error page
htmlBody := fmt.Sprintf(`
@@ -1537,7 +1706,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
<p><a href="%s">Return to application</a></p>
</div>
</body>
</html>`, message, returnURL)
</html>`, message, returnURL) // Use default returnURL
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(code)
+71 -17
View File
@@ -385,9 +385,52 @@ func TestServeHTTP(t *testing.T) {
expectedBody: "OK",
},
{
name: "Unauthenticated request to protected URL",
requestPath: "/protected",
expectedStatus: http.StatusFound, // Expect redirect to OIDC
name: "Unauthenticated request (no refresh token) to protected URL",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Ensure no tokens are set
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token-for-unauth-test" {
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
}
// Simulate successful refresh
newToken := createNewValidToken() // Use helper from TestServeHTTP
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
}
},
expectedStatus: http.StatusOK, // Expect OK after successful refresh
expectedBody: "OK",
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid")
}
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
},
{
name: "Authenticated request to protected URL (Valid Token)",
@@ -407,11 +450,15 @@ func TestServeHTTP(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: "OK",
},
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
{
name: "Authenticated request with expired token and successful refresh",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Still marked authenticated initially
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Set expired token
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
@@ -445,16 +492,19 @@ func TestServeHTTP(t *testing.T) {
t.Fatalf("Failed to get session after request: %v", err)
}
// Assert new tokens are in the session
// Direct comparison with createNewValidToken() is flawed as it generates a new token each time.
// Instead, check if the token was updated (not empty) and verify the refresh token.
if session.GetAccessToken() == "" {
t.Errorf("Expected access token to be updated in session, but it was empty")
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
}
if session.GetRefreshToken() != "new-refresh-token" {
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
}
// Also check authenticated flag is now true
if !session.GetAuthenticated() {
t.Errorf("Expected session to be marked authenticated after successful refresh")
}
},
},
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
{
name: "Logout URL",
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
@@ -477,10 +527,10 @@ func TestServeHTTP(t *testing.T) {
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
session.SetAccessToken(createExpiredToken()) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -491,17 +541,18 @@ func TestServeHTTP(t *testing.T) {
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
},
// This test case remains valid as the logic should still redirect browser clients on refresh failure
{
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
session.SetAccessToken(createExpiredToken()) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -512,8 +563,9 @@ func TestServeHTTP(t *testing.T) {
requestHeaders: map[string]string{
"Accept": "text/html", // Browser client
},
expectedStatus: http.StatusFound, // Expect redirect for browser client
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
},
// This test case remains valid as proactive refresh should still be attempted
{
name: "Authenticated request with token nearing expiry (needs refresh)",
requestPath: "/protected",
@@ -529,7 +581,7 @@ func TestServeHTTP(t *testing.T) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry")
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -544,6 +596,7 @@ func TestServeHTTP(t *testing.T) {
expectedStatus: http.StatusOK, // Expect success after proactive refresh
expectedBody: "OK",
},
// This test case remains valid as no refresh should be attempted
{
name: "Authenticated request with token valid (outside grace period)",
requestPath: "/protected",
@@ -1531,6 +1584,7 @@ func TestRevokeToken(t *testing.T) {
tOidc := &TraefikOidc{
tokenBlacklist: NewCache(), // Use generic cache for blacklist
tokenCache: NewTokenCache(),
logger: NewLogger("info"), // Initialize the logger
}
// Cache the token