mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Bugfix: Refresh token not obtained when access token is expired.
This commit is contained in:
@@ -17,6 +17,14 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// min returns the smaller of x or y.
|
||||
func min(x, y int) int {
|
||||
if x > y {
|
||||
return y
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// createDefaultHTTPClient creates a new http.Client with settings optimized for OIDC communication.
|
||||
// It configures the transport with specific timeouts (dial, keepalive, TLS handshake, idle connection),
|
||||
// connection limits (max idle, max per host), enables HTTP/2, and sets a default request timeout.
|
||||
@@ -442,6 +450,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to get provider metadata: %v", err)
|
||||
// Consider retrying or handling this more gracefully
|
||||
return
|
||||
}
|
||||
|
||||
@@ -457,7 +466,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Error("Received nil metadata")
|
||||
t.logger.Error("Received nil metadata during initialization")
|
||||
// Consider what should happen if metadata is nil after GetMetadata returns no error
|
||||
}
|
||||
|
||||
// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.)
|
||||
@@ -497,6 +507,8 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
if metadata != nil {
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.logger.Debug("Successfully refreshed metadata")
|
||||
} else {
|
||||
t.logger.Error("Received nil metadata during refresh")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -544,7 +556,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
|
||||
l.Debugf("Failed to fetch provider metadata (attempt %d/%d), retrying in %s. Error: %v", attempt+1, maxRetries, delay, err)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
@@ -568,64 +580,56 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("received nil response from provider")
|
||||
return nil, fmt.Errorf("received nil response from provider at %s", wellKnownURL)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata: status code %d", resp.StatusCode)
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("failed to fetch provider metadata from %s: status code %d, body: %s", wellKnownURL, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var metadata ProviderMetadata
|
||||
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
|
||||
// Attempt to read body for better error context if decoding fails
|
||||
// Note: resp.Body might be partially read by Decode, so read remaining
|
||||
bodyBytes, readErr := io.ReadAll(io.MultiReader(json.NewDecoder(resp.Body).Buffered(), resp.Body))
|
||||
if readErr != nil {
|
||||
bodyBytes = []byte(fmt.Sprintf("(failed to read response body: %v)", readErr))
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode provider metadata from %s: %w. Response body: %s", wellKnownURL, err, string(bodyBytes))
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main entry point for incoming requests to the middleware.
|
||||
// It orchestrates the OIDC authentication flow:
|
||||
// 1. Waits for initial OIDC metadata discovery to complete (with timeout).
|
||||
// 2. Checks if the request path is excluded from authentication.
|
||||
// 3. Checks if the request is for Server-Sent Events and bypasses if so.
|
||||
// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing.
|
||||
// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout).
|
||||
// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry).
|
||||
// 7. If the token is expired, handles it (initiates re-auth).
|
||||
// 8. If the user is not authenticated, initiates authentication.
|
||||
// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure
|
||||
// by returning 401 for API clients or initiating re-auth for browsers.
|
||||
// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups).
|
||||
// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*)
|
||||
// and adds security headers (X-Frame-Options, etc.) to the response.
|
||||
// 12. Forwards the request to the next handler in the chain.
|
||||
// It orchestrates the OIDC authentication flow.
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// --- Initialization Check ---
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
if t.issuerURL == "" {
|
||||
t.logger.Error("OIDC provider metadata initialization failed")
|
||||
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
|
||||
if t.issuerURL == "" { // Check if initialization actually succeeded
|
||||
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
|
||||
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled")
|
||||
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
|
||||
t.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
||||
http.Error(rw, "Request cancelled", http.StatusRequestTimeout) // 408 might be more appropriate
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
case <-time.After(30 * time.Second): // Timeout for initialization
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
|
||||
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if URL is excluded
|
||||
// --- Excluded Paths & SSE Check ---
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the request expects Server-Sent Events
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "text/event-stream") {
|
||||
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
|
||||
@@ -633,80 +637,119 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get session
|
||||
// --- Session Retrieval ---
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
|
||||
// Obtain a new session and clear any residual session cookies
|
||||
session, _ = t.sessionManager.GetSession(req)
|
||||
session.Clear(req, rw)
|
||||
|
||||
// Build redirect URL
|
||||
// Log the specific session error
|
||||
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
|
||||
// Attempt to get a new session to store CSRF etc.
|
||||
session, _ = t.sessionManager.GetSession(req) // Ignore error here, proceed with new session
|
||||
if session != nil {
|
||||
// Pass rw to ensure expiring cookies are sent if possible
|
||||
if clearErr := session.Clear(req, rw); clearErr != nil {
|
||||
t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr)
|
||||
}
|
||||
} else {
|
||||
// If even getting a new session fails, something is very wrong
|
||||
t.logger.Error("Critical session error: Failed to get even a new session.")
|
||||
http.Error(rw, "Critical session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Initiate authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URL
|
||||
// --- URL Handling (Callback, Logout) ---
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath) // Used for callback and re-auth
|
||||
|
||||
// Handle special URLs
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Check authentication status
|
||||
// --- Authentication & Refresh Logic ---
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
if expired {
|
||||
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
|
||||
// handleExpiredToken clears the session and initiates auth
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
// Original logic: Always initiate authentication if not authenticated
|
||||
t.logger.Debug("User not authenticated, initiating OIDC flow")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
// If authenticated and token doesn't need proactive refresh, proceed directly
|
||||
if authenticated && !needsRefresh {
|
||||
t.logger.Debug("User authenticated and token valid, proceeding to process authorized request")
|
||||
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// --- Attempt Refresh if Needed or Possible ---
|
||||
// Conditions to attempt refresh:
|
||||
// 1. Token needs proactive refresh (authenticated=true, needsRefresh=true)
|
||||
// 2. Token is invalid/expired but a refresh token exists (authenticated=false, needsRefresh=true)
|
||||
refreshTokenPresent := session.GetRefreshToken() != ""
|
||||
shouldAttemptRefresh := needsRefresh && refreshTokenPresent
|
||||
|
||||
if shouldAttemptRefresh {
|
||||
if needsRefresh && authenticated {
|
||||
t.logger.Debug("Session token needs proactive refresh, attempting refresh")
|
||||
} else if needsRefresh && !authenticated {
|
||||
t.logger.Debug("Access token invalid/expired, but refresh token found. Attempting refresh.")
|
||||
}
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
// Refresh succeeded, proceed to authorization checks
|
||||
t.logger.Debug("Token refresh successful, proceeding to process authorized request")
|
||||
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh failed
|
||||
t.logger.Infof("Token refresh failed (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
|
||||
// Handle refresh failure (401 for API, re-auth for browser)
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
|
||||
} else {
|
||||
t.logger.Debug("Client does not prefer JSON, handling refresh failure by initiating re-auth")
|
||||
// Use defaultInitiateAuthentication which clears the session properly
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
return // Stop processing
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
t.logger.Infof("Token refresh failed") // Changed from Warn to Infof
|
||||
// Check if the client prefers JSON (likely an API call)
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
|
||||
} else {
|
||||
// Client likely a browser, initiate full re-authentication
|
||||
t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)")
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
}
|
||||
return // Stop processing
|
||||
}
|
||||
}
|
||||
// --- Initiate Full Authentication ---
|
||||
// If we reach here, it means:
|
||||
// - User is not authenticated (!authenticated)
|
||||
// - AND EITHER token doesn't need refresh (!needsRefresh, e.g., first visit)
|
||||
// - OR refresh token is missing (!refreshTokenPresent)
|
||||
// - OR refresh was attempted but failed (handled above)
|
||||
t.logger.Debugf("Initiating full OIDC authentication flow (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// Process authenticated request
|
||||
// processAuthorizedRequest handles the final steps for an authenticated and authorized request.
|
||||
// It performs domain/role/group checks, sets headers, and forwards the request.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Debug("No email found in session")
|
||||
t.logger.Error("CRITICAL: No email found in session during final processing, initiating re-auth")
|
||||
// This case should ideally not happen if checks are done correctly before calling this,
|
||||
// but as a safeguard, initiate re-authentication.
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
@@ -721,6 +764,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Continue without group/role headers if extraction fails
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
@@ -779,6 +823,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
// Process the request
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -794,19 +839,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing the expired token information.
|
||||
// - redirectURL: The callback URL to be used in the new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear authentication data but preserve CSRF state
|
||||
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
||||
// Clear authentication data but preserve CSRF state if possible (though Clear might remove it)
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
|
||||
// Save the cleared session state
|
||||
// Save the cleared session state (this sends expired cookies)
|
||||
// Pass rw to ensure expiring cookies are sent
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save cleared session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
|
||||
// Still attempt to initiate authentication, but log the error
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
@@ -834,8 +881,8 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
t.logger.Errorf("Session error during callback: %v", err)
|
||||
http.Error(rw, "Session error during callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -847,7 +894,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error") // Use error code if description is empty
|
||||
}
|
||||
t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -862,13 +909,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest)
|
||||
t.logger.Error("CSRF token missing in session during callback")
|
||||
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
t.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -886,22 +933,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify tokens and claims
|
||||
// Use the exported VerifyToken method now that handleCallback is in main.go
|
||||
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
t.logger.Errorf("Failed to extract claims during callback: %v", err)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -909,51 +955,68 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
// Verify nonce to prevent replay attacks
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
t.logger.Error("Nonce claim missing in id_token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
t.logger.Error("Nonce not found in session during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
t.logger.Error("Nonce claim does not match session nonce during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate user's email domain
|
||||
// Use the unexported isAllowedDomain method now that handleCallback is in main.go
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with authentication data
|
||||
session.SetAuthenticated(true)
|
||||
// Regenerate session ID upon successful authentication
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
||||
http.Error(rw, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Clear CSRF, Nonce, CodeVerifier after use
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
// Redirect to original path or root
|
||||
// Retrieve original path *before* saving, as save might clear it if Clear was called concurrently
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
session.SetIncomingPath("") // Clear incoming path after retrieving it
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after callback: %v", err)
|
||||
http.Error(rw, "Failed to save session after callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to original path or root
|
||||
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -972,7 +1035,7 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
|
||||
// t.logger.Debugf("URL is not excluded - got %s", currentRequest) // Too verbose for every request
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1024,32 +1087,58 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims).
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
return false, false, false
|
||||
t.logger.Debug("User is not authenticated according to session flag")
|
||||
// Check if there's still a refresh token - if so, refresh might be possible
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false // Not authenticated, NeedsRefresh=true (to attempt recovery), Expired=false
|
||||
}
|
||||
return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth)
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
t.logger.Debug("No access token found in session")
|
||||
return false, false, true // Session is invalid, consider it expired
|
||||
t.logger.Debug("Authenticated flag set, but no access token found in session")
|
||||
// If authenticated flag is true but token is missing, treat as expired/invalid session state
|
||||
// Check for refresh token before declaring fully expired
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Authenticated flag set, access token missing, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false // Not authenticated (no access token), NeedsRefresh=true, Expired=false
|
||||
}
|
||||
return false, false, true // No access or refresh token, treat as expired
|
||||
}
|
||||
|
||||
// Verify the token structure and signature first
|
||||
jwt, err := parseJWT(accessToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to parse JWT during auth check: %v", err)
|
||||
return false, false, true // Invalid format, treat as expired/invalid
|
||||
// Check for refresh token before declaring fully expired
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Access token parsing failed, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
|
||||
}
|
||||
return false, false, true // Invalid format, no refresh token, treat as expired/invalid
|
||||
}
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil {
|
||||
// Check if the error is specifically about expiration
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh")
|
||||
t.logger.Debugf("Access token signature/claims valid but token expired, needs refresh")
|
||||
// Token is expired but otherwise valid, signal for refresh
|
||||
return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible)
|
||||
// Return authenticated=false because the current token is unusable
|
||||
// NeedsRefresh is true only if a refresh token exists
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it)
|
||||
}
|
||||
return false, false, true // Expired access token, no refresh token, treat as expired
|
||||
}
|
||||
// Other verification error (signature, issuer, audience etc.)
|
||||
t.logger.Errorf("Token verification failed (non-expiration): %v", err)
|
||||
return false, false, true // Token is invalid for other reasons
|
||||
t.logger.Errorf("Access token verification failed (non-expiration): %v", err)
|
||||
// Check for refresh token before declaring fully expired
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Access token verification failed, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
|
||||
}
|
||||
return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session
|
||||
}
|
||||
|
||||
// Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early
|
||||
@@ -1057,8 +1146,13 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
|
||||
expClaim, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Error("Failed to get expiration time from claims")
|
||||
return false, false, true
|
||||
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
|
||||
// Check for refresh token before declaring fully expired
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Access token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
|
||||
}
|
||||
return false, false, true // Treat as invalid if 'exp' is missing and no refresh token
|
||||
}
|
||||
|
||||
expTime := int64(expClaim)
|
||||
@@ -1071,12 +1165,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
|
||||
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
|
||||
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
|
||||
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod)
|
||||
return true, true, false // Needs proactive refresh
|
||||
t.logger.Debugf("Access token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod)
|
||||
// Token is still valid, but we should refresh it soon
|
||||
// NeedsRefresh is true only if a refresh token exists
|
||||
if session.GetRefreshToken() != "" {
|
||||
return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false
|
||||
}
|
||||
// If no refresh token, we can't proactively refresh, treat as normal valid token for now
|
||||
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// Token is valid, not expired, and not nearing expiration
|
||||
return true, false, false
|
||||
return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow.
|
||||
@@ -1092,10 +1193,12 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
// - session: The user's SessionData object (potentially new or cleared).
|
||||
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -1106,37 +1209,41 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
var err error
|
||||
codeVerifier, err = generateCodeVerifier()
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to generate code verifier: %v", err)
|
||||
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Derive code challenge from verifier
|
||||
codeChallenge = deriveCodeChallenge(codeVerifier)
|
||||
t.logger.Debugf("PKCE enabled, generated code challenge")
|
||||
}
|
||||
|
||||
// Clear any existing session data to avoid stale state causing redirect loops
|
||||
session.Clear(req, rw)
|
||||
// Pass the response writer to ensure expiring cookies are sent
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
// Log the error but continue, as clearing is best-effort before re-auth
|
||||
t.logger.Errorf("Error clearing session before initiating authentication: %v", err)
|
||||
}
|
||||
|
||||
// Set new session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
|
||||
// Only set code verifier if PKCE is enabled
|
||||
if t.enablePKCE {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
|
||||
// Store the original path the user was trying to access
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
t.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
|
||||
|
||||
// Save the session
|
||||
// Save the session (to store CSRF, Nonce, etc.)
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build and redirect to authentication URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -1185,6 +1292,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
params.Set("scope", strings.Join(t.scopes, " "))
|
||||
}
|
||||
|
||||
// Use buildURLWithParams which handles potential relative authURL from metadata
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
}
|
||||
|
||||
@@ -1201,17 +1309,30 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
// Ensure URL is absolute
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
// Extract issuer base URL
|
||||
issuerURL, err := url.Parse(t.issuerURL)
|
||||
// Attempt to resolve relative URL against issuer URL
|
||||
issuerURLParsed, err := url.Parse(t.issuerURL)
|
||||
if err == nil {
|
||||
return fmt.Sprintf("%s://%s%s?%s",
|
||||
issuerURL.Scheme,
|
||||
issuerURL.Host,
|
||||
baseURL,
|
||||
params.Encode())
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err == nil {
|
||||
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
||||
resolvedURL.RawQuery = params.Encode()
|
||||
return resolvedURL.String()
|
||||
}
|
||||
}
|
||||
// Fallback if parsing fails - append params to potentially relative path
|
||||
t.logger.Errorf("Could not parse issuerURL or baseURL to resolve relative URL. BaseURL: %s, IssuerURL: %s", baseURL, t.issuerURL)
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
return baseURL + "?" + params.Encode()
|
||||
|
||||
// If baseURL is already absolute
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
|
||||
// Fallback: append params directly
|
||||
return baseURL + "?" + params.Encode()
|
||||
}
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts background goroutines for periodically cleaning up
|
||||
@@ -1246,6 +1367,7 @@ func (t *TraefikOidc) RevokeToken(token string) {
|
||||
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
|
||||
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
|
||||
t.tokenBlacklist.Set(token, true, time.Until(expiry))
|
||||
t.logger.Debugf("Locally revoked token (added to blacklist)")
|
||||
}
|
||||
|
||||
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
|
||||
@@ -1260,7 +1382,10 @@ func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// - nil if the revocation request is successful (provider returns 200 OK).
|
||||
// - An error if the request fails or the provider returns a non-OK status.
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
t.logger.Debugf("Revoking token with provider")
|
||||
if t.revocationURL == "" {
|
||||
return fmt.Errorf("token revocation endpoint is not configured or discovered")
|
||||
}
|
||||
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
@@ -1277,6 +1402,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json") // Prefer JSON response if available
|
||||
|
||||
// Send the request
|
||||
resp, err := t.httpClient.Do(req)
|
||||
@@ -1288,18 +1414,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
// Check the response
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("token revocation failed with status %d: %s", resp.StatusCode, string(body))
|
||||
// Log the failure details
|
||||
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
|
||||
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token successfully revoked")
|
||||
t.logger.Debugf("Token successfully revoked with provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens.
|
||||
// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session.
|
||||
// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method,
|
||||
// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens,
|
||||
// and saves the session.
|
||||
// verifies the newly obtained ID token using verifyToken, performs a concurrency check,
|
||||
// updates the session with the new tokens if the check passes, and saves the session.
|
||||
//
|
||||
// Parameters:
|
||||
// - rw: The HTTP response writer (needed for saving the updated session).
|
||||
@@ -1308,45 +1436,85 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
//
|
||||
// Returns:
|
||||
// - true if the token refresh was successful and the session was updated.
|
||||
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed.
|
||||
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification,
|
||||
// a concurrency conflict was detected, or saving the session failed.
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
// Lock the mutex specific to this session instance before attempting refresh
|
||||
session.refreshMutex.Lock()
|
||||
defer session.refreshMutex.Unlock()
|
||||
|
||||
t.logger.Debug("Attempting to refresh token (mutex acquired)")
|
||||
refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
|
||||
if refreshToken == "" {
|
||||
t.logger.Debug("No refresh token found in session (inside lock)")
|
||||
initialRefreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
|
||||
if initialRefreshToken == "" {
|
||||
t.logger.Errorf("refreshToken failed: No refresh token found in session (after acquiring lock)")
|
||||
return false
|
||||
}
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
// Store the initial token for later comparison
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)])
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
if err != nil {
|
||||
// Log the error, potentially clear the invalid refresh token?
|
||||
t.logger.Errorf("Failed to refresh token using refresh token: %v", err)
|
||||
// Consider clearing the refresh token from the session here if the error indicates it's invalid
|
||||
// session.SetRefreshToken("") // Example: Clear potentially invalid token
|
||||
// session.Save(req, rw) // Need to handle potential save error
|
||||
// Log the error more explicitly before returning false
|
||||
truncatedToken := initialRefreshToken[:min(len(initialRefreshToken), 10)] // Log first 10 chars
|
||||
t.logger.Errorf("refreshToken failed: Error from tokenExchanger.GetNewTokenWithRefreshToken for token starting with %s...: %v", truncatedToken, err)
|
||||
// No need to clear token here, as the session might be cleared by another request anyway
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the new access token
|
||||
// Verify the new access token (ID token)
|
||||
if err := t.verifyToken(newToken.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify new access token: %v", err)
|
||||
truncatedNewToken := newToken.IDToken[:min(len(newToken.IDToken), 10)] // Log first 10 chars
|
||||
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update session with new tokens
|
||||
// --- Concurrency Check ---
|
||||
// Before saving the new token, check if the session state (specifically the refresh token)
|
||||
// has been modified concurrently (e.g., by a logout or another auth initiation).
|
||||
currentRefreshToken := session.GetRefreshToken() // Get token again *after* the potentially long exchange
|
||||
if initialRefreshToken != currentRefreshToken {
|
||||
// Use Infof as Warnf doesn't exist
|
||||
t.logger.Infof("refreshToken aborted: Session refresh token changed concurrently during refresh attempt. Initial token prefix: %s..., Current token prefix: %s...", initialRefreshToken[:min(len(initialRefreshToken), 10)], currentRefreshToken[:min(len(currentRefreshToken), 10)])
|
||||
// Do not save the new tokens, as the session state is likely invalid/cleared.
|
||||
return false // Indicate refresh failure due to concurrency conflict
|
||||
}
|
||||
// --- End Concurrency Check ---
|
||||
|
||||
// Update session with new tokens ONLY if the concurrency check passed
|
||||
t.logger.Debugf("Concurrency check passed. Updating session with new tokens.")
|
||||
|
||||
// Extract email from the new token and update session
|
||||
claims, err := t.extractClaimsFunc(newToken.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false // Cannot proceed without claims
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false // Cannot proceed without email
|
||||
}
|
||||
session.SetEmail(email) // Update email in session
|
||||
|
||||
session.SetAccessToken(newToken.IDToken)
|
||||
session.SetRefreshToken(newToken.RefreshToken)
|
||||
// Ensure the new refresh token is actually set, even if it's the same as the old one
|
||||
// Also handle cases where the provider might not return a new refresh token
|
||||
if newToken.RefreshToken != "" {
|
||||
session.SetRefreshToken(newToken.RefreshToken)
|
||||
} else {
|
||||
// If no new refresh token is returned, keep the existing one
|
||||
session.SetRefreshToken(initialRefreshToken)
|
||||
t.logger.Debugf("Provider did not return a new refresh token, keeping the existing one.")
|
||||
}
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save refreshed session: %v", err)
|
||||
t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh and concurrency check: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token refresh successful and session saved.")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1367,6 +1535,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
t.logger.Errorf("Invalid email format encountered: %s", email)
|
||||
return false // Invalid email format
|
||||
}
|
||||
|
||||
@@ -1400,12 +1569,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
if groupsClaim, exists := claims["groups"]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
// Strictly expect an array
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
}
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
} else {
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1414,12 +1587,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
if rolesClaim, exists := claims["roles"]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
// Strictly expect an array
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
}
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
} else {
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1493,8 +1670,8 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
rw.WriteHeader(code)
|
||||
// Use a simple error structure
|
||||
json.NewEncoder(rw).Encode(map[string]interface{}{
|
||||
"error": http.StatusText(code),
|
||||
"error_description": message,
|
||||
"error": http.StatusText(code), // Use standard text for the code
|
||||
"error_description": message, // Provide specific detail here
|
||||
"status_code": code,
|
||||
})
|
||||
return
|
||||
@@ -1504,17 +1681,9 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
|
||||
|
||||
// Determine the return URL (mostly relevant for HTML)
|
||||
returnURL := "/" // Default to root
|
||||
session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL
|
||||
if err == nil {
|
||||
incomingPath := session.GetIncomingPath()
|
||||
// Use incoming path if it's valid and not one of the special OIDC paths
|
||||
if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath {
|
||||
returnURL = incomingPath
|
||||
}
|
||||
} else {
|
||||
t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err)
|
||||
}
|
||||
returnURL := "/" // Default to root
|
||||
// No need to get session here, as we are already in an error path
|
||||
// where session might be invalid or unavailable.
|
||||
|
||||
// Basic HTML structure for the error page
|
||||
htmlBody := fmt.Sprintf(`
|
||||
@@ -1537,7 +1706,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
<p><a href="%s">Return to application</a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, message, returnURL)
|
||||
</html>`, message, returnURL) // Use default returnURL
|
||||
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.WriteHeader(code)
|
||||
|
||||
+71
-17
@@ -385,9 +385,52 @@ func TestServeHTTP(t *testing.T) {
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated request to protected URL",
|
||||
requestPath: "/protected",
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC
|
||||
name: "Unauthenticated request (no refresh token) to protected URL",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
// Ensure no tokens are set
|
||||
session.SetAuthenticated(false)
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(false) // Not authenticated
|
||||
session.SetAccessToken("") // No access token
|
||||
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
if refreshToken != "valid-refresh-token-for-unauth-test" {
|
||||
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
||||
}
|
||||
// Simulate successful refresh
|
||||
newToken := createNewValidToken() // Use helper from TestServeHTTP
|
||||
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusOK, // Expect OK after successful refresh
|
||||
expectedBody: "OK",
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(false) // Not authenticated
|
||||
session.SetAccessToken("") // No access token
|
||||
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
// Simulate failed refresh
|
||||
return nil, fmt.Errorf("mock error: refresh token invalid")
|
||||
}
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
|
||||
},
|
||||
{
|
||||
name: "Authenticated request to protected URL (Valid Token)",
|
||||
@@ -407,11 +450,15 @@ func TestServeHTTP(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "OK",
|
||||
},
|
||||
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
|
||||
{
|
||||
name: "Authenticated request with expired token and successful refresh",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Still marked authenticated initially
|
||||
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken()) // Set expired token
|
||||
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
|
||||
@@ -445,16 +492,19 @@ func TestServeHTTP(t *testing.T) {
|
||||
t.Fatalf("Failed to get session after request: %v", err)
|
||||
}
|
||||
// Assert new tokens are in the session
|
||||
// Direct comparison with createNewValidToken() is flawed as it generates a new token each time.
|
||||
// Instead, check if the token was updated (not empty) and verify the refresh token.
|
||||
if session.GetAccessToken() == "" {
|
||||
t.Errorf("Expected access token to be updated in session, but it was empty")
|
||||
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
|
||||
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
|
||||
}
|
||||
if session.GetRefreshToken() != "new-refresh-token" {
|
||||
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
|
||||
}
|
||||
// Also check authenticated flag is now true
|
||||
if !session.GetAuthenticated() {
|
||||
t.Errorf("Expected session to be marked authenticated after successful refresh")
|
||||
}
|
||||
},
|
||||
},
|
||||
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
|
||||
{
|
||||
name: "Logout URL",
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
@@ -477,10 +527,10 @@ func TestServeHTTP(t *testing.T) {
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken())
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
session.SetAccessToken(createExpiredToken()) // Expired access token
|
||||
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
@@ -491,17 +541,18 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client
|
||||
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
|
||||
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
|
||||
},
|
||||
// This test case remains valid as the logic should still redirect browser clients on refresh failure
|
||||
{
|
||||
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(createExpiredToken())
|
||||
session.SetRefreshToken("valid-refresh-token")
|
||||
session.SetAccessToken(createExpiredToken()) // Expired access token
|
||||
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
@@ -512,8 +563,9 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestHeaders: map[string]string{
|
||||
"Accept": "text/html", // Browser client
|
||||
},
|
||||
expectedStatus: http.StatusFound, // Expect redirect for browser client
|
||||
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
|
||||
},
|
||||
// This test case remains valid as proactive refresh should still be attempted
|
||||
{
|
||||
name: "Authenticated request with token nearing expiry (needs refresh)",
|
||||
requestPath: "/protected",
|
||||
@@ -529,7 +581,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry")
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
||||
return func(refreshToken string) (*TokenResponse, error) {
|
||||
@@ -544,6 +596,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
expectedStatus: http.StatusOK, // Expect success after proactive refresh
|
||||
expectedBody: "OK",
|
||||
},
|
||||
// This test case remains valid as no refresh should be attempted
|
||||
{
|
||||
name: "Authenticated request with token valid (outside grace period)",
|
||||
requestPath: "/protected",
|
||||
@@ -1531,6 +1584,7 @@ func TestRevokeToken(t *testing.T) {
|
||||
tOidc := &TraefikOidc{
|
||||
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
||||
tokenCache: NewTokenCache(),
|
||||
logger: NewLogger("info"), // Initialize the logger
|
||||
}
|
||||
|
||||
// Cache the token
|
||||
|
||||
Reference in New Issue
Block a user