Compare commits

..

4 Commits

5 changed files with 709 additions and 201 deletions
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Lukasz Raczylo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+58 -8
View File
@@ -69,14 +69,15 @@ The middleware supports the following configuration options:
| `postLogoutRedirectURI` | The URL to redirect to after logout | `/` | `/logged-out-page` |
| `scopes` | The OAuth 2.0 scopes to request | `["openid", "profile", "email"]` | `["openid", "email", "profile", "roles"]` |
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
| | `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| | `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| | `excludedURLs` | Lists paths that bypass authentication | `["/favicon"]` | `["/health", "/metrics", "/public"]` |
| | `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
## Usage Examples
@@ -258,6 +259,34 @@ spec:
- profile
```
### Google OIDC Configuration Example
This example shows a configuration specifically tailored for Google OIDC, including necessary scopes for session extension:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-google
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-google-client-id.apps.googleusercontent.com # Replace with your Client ID
clientSecret: your-google-client-secret # Replace with your Client Secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars # Replace with your key
callbackURL: /oauth2/callback # Adjust if needed
logoutURL: /oauth2/logout # Optional: Adjust if needed
scopes:
- openid
- email
- profile
- offline_access # Required for refresh tokens / long sessions with Google
refreshGracePeriodSeconds: 300 # Optional: Start refresh 5 min before expiry (default 60)
# Other optional parameters like allowedUserDomains, etc. can be added here
```
### Keeping Secrets Secret in Kubernetes
For Kubernetes environments, you can reference secrets instead of hardcoding sensitive values:
@@ -414,6 +443,23 @@ PKCE is recommended when:
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
### Session Duration and Token Refresh
This middleware aims to provide long-lived user sessions, typically up to 24 hours, by utilizing OIDC refresh tokens.
**How it works:**
- When a user authenticates, the middleware requests an access token and, if available, a refresh token from the OIDC provider.
- The access token usually has a short lifespan (e.g., 1 hour).
- Before the access token expires (controlled by `refreshGracePeriodSeconds`), the middleware uses the refresh token to obtain a new access token from the provider without requiring the user to log in again.
- This process repeats, allowing the session to remain valid for as long as the refresh token is valid (often 24 hours or more, depending on the provider).
**Provider-Specific Considerations (e.g., Google):**
- Some providers, like Google, issue short-lived access tokens (e.g., 1 hour) and require specific configurations for long-term sessions.
- To enable session extension beyond the initial token expiry with Google and similar providers, the middleware automatically includes the `offline_access` scope in the authentication request. This scope is necessary to obtain a refresh token.
- For Google specifically, the middleware also adds the `prompt=consent` parameter to the initial authorization request. This ensures Google issues a refresh token, which is crucial for extending the session.
- If a refresh attempt fails (e.g., the refresh token is revoked or expired), the user will be required to re-authenticate. The middleware includes enhanced error handling and logging for these scenarios.
- Ensure your OIDC provider is configured to issue refresh tokens and allows their use for extending sessions. Check your provider's documentation for details on refresh token validity periods.
### Token Caching and Blacklisting
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
@@ -455,6 +501,10 @@ logLevel: debug
3. **No matching public key found**: The JWKS endpoint might be unavailable or the token's key ID (kid) doesn't match any key in the JWKS.
4. **Access denied: Your email domain is not allowed**: The user's email domain is not in the `allowedUserDomains` list.
5. **Access denied: You do not have any of the allowed roles or groups**: The user doesn't have any of the roles or groups specified in `allowedRolesAndGroups`.
6. **Google sessions expire after ~1 hour**: If using Google as the OIDC provider and sessions expire prematurely (around 1 hour instead of longer), ensure:
- The `offline_access` scope is included in your configuration (the middleware adds this automatically now, but verify if manually configured).
- Your Google Cloud OAuth consent screen is set to "External" and "Production" mode. "Testing" mode often limits refresh token validity.
- The fix involving automatic `offline_access` scope and `prompt=consent` for Google is active in your middleware version. Check the plugin version corresponds to when this fix was implemented. Enhanced logging around refresh token failures can provide more clues if issues persist.
## Contributing
+147
View File
@@ -0,0 +1,147 @@
package traefikoidc
import (
"fmt"
"net/http/httptest"
"strings"
"testing"
"time"
)
// MockTokenVerifier implements the TokenVerifier interface for testing
type MockTokenVerifier struct {
VerifyFunc func(token string) error
}
func (m *MockTokenVerifier) VerifyToken(token string) error {
if m.VerifyFunc != nil {
return m.VerifyFunc(token)
}
return nil
}
func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
// Create a mocked TraefikOidc instance that simulates Google provider behavior
mockLogger := NewLogger("debug")
// Create a test instance with a Google-like issuer URL
tOidc := &TraefikOidc{
issuerURL: "https://accounts.google.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60,
}
// Create a session manager
sessionManager, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, mockLogger)
tOidc.sessionManager = sessionManager
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
// Test buildAuthURL to ensure it adds offline_access and prompt=consent for Google
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that offline_access scope was added
if !strings.Contains(authURL, "scope=") || !strings.Contains(authURL, "offline_access") {
t.Errorf("offline_access scope not added to Google auth URL: %s", authURL)
}
// Check that prompt=consent was added
if !strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
}
})
t.Run("Non-Google provider doesn't add Google-specific params", func(t *testing.T) {
// Create a test instance with a non-Google issuer URL
nonGoogleOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
logger: mockLogger,
scopes: []string{"openid", "profile", "email"},
}
// Test buildAuthURL without Google-specific parameters
authURL := nonGoogleOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
// Check that prompt=consent is not automatically added
if strings.Contains(authURL, "prompt=consent") {
t.Errorf("prompt=consent added to non-Google auth URL: %s", authURL)
}
})
t.Run("Session refresh with Google provider", func(t *testing.T) {
// Create a request and response recorder
req := httptest.NewRequest("GET", "/test", nil)
rw := httptest.NewRecorder()
// Create a session and set a refresh token
session, _ := sessionManager.GetSession(req)
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("valid-refresh-token")
// Create a mock token exchanger that simulates Google's behavior
mockTokenExchanger := &MockTokenExchanger{
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
// Check that the refresh token is passed correctly
if refreshToken != "valid-refresh-token" {
t.Errorf("Incorrect refresh token passed: %s", refreshToken)
return nil, fmt.Errorf("invalid token")
}
// Return a simulated Google token response with a new access token
// but without a new refresh token (Google doesn't always return a new refresh token)
return &TokenResponse{
IDToken: "new-id-token-from-google",
AccessToken: "new-access-token-from-google",
RefreshToken: "", // Google often doesn't return a new refresh token
ExpiresIn: 3600,
}, nil
},
}
// Set the mock token exchanger
tOidc.tokenExchanger = mockTokenExchanger
// Create a struct that implements the TokenVerifier interface
tOidc.tokenVerifier = &MockTokenVerifier{
VerifyFunc: func(token string) error {
return nil
},
}
tOidc.extractClaimsFunc = func(token string) (map[string]interface{}, error) {
// Return mock claims
return map[string]interface{}{
"email": "test@example.com",
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
}, nil
}
// Attempt to refresh the token
refreshed := tOidc.refreshToken(rw, req, session)
// Verify the refresh was successful
if !refreshed {
t.Error("Token refresh failed for Google provider")
}
// Check that we kept the original refresh token since Google didn't provide a new one
if session.GetRefreshToken() != "valid-refresh-token" {
t.Errorf("Original refresh token not preserved: got %s, expected 'valid-refresh-token'",
session.GetRefreshToken())
}
// Check that the access token was updated
if session.GetAccessToken() != "new-id-token-from-google" {
t.Errorf("Access token not updated: got %s, expected 'new-id-token-from-google'",
session.GetAccessToken())
}
})
}
// No need to redefine MockTokenExchanger - it's already defined in main_test.go
+412 -176
View File
@@ -442,6 +442,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to get provider metadata: %v", err)
// Consider retrying or handling this more gracefully
return
}
@@ -457,7 +458,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
return
}
t.logger.Error("Received nil metadata")
t.logger.Error("Received nil metadata during initialization")
// Consider what should happen if metadata is nil after GetMetadata returns no error
}
// updateMetadataEndpoints updates the relevant endpoint URL fields (jwksURL, authURL, tokenURL, etc.)
@@ -497,6 +499,8 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
if metadata != nil {
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
} else {
t.logger.Error("Received nil metadata during refresh")
}
}
}
@@ -544,7 +548,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
if delay > maxDelay {
delay = maxDelay
}
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
l.Debugf("Failed to fetch provider metadata (attempt %d/%d), retrying in %s. Error: %v", attempt+1, maxRetries, delay, err)
time.Sleep(delay)
}
@@ -568,64 +572,56 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("received nil response from provider")
return nil, fmt.Errorf("received nil response from provider at %s", wellKnownURL)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch provider metadata: status code %d", resp.StatusCode)
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("failed to fetch provider metadata from %s: status code %d, body: %s", wellKnownURL, resp.StatusCode, string(bodyBytes))
}
var metadata ProviderMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
// Attempt to read body for better error context if decoding fails
// Note: resp.Body might be partially read by Decode, so read remaining
bodyBytes, readErr := io.ReadAll(io.MultiReader(json.NewDecoder(resp.Body).Buffered(), resp.Body))
if readErr != nil {
bodyBytes = []byte(fmt.Sprintf("(failed to read response body: %v)", readErr))
}
return nil, fmt.Errorf("failed to decode provider metadata from %s: %w. Response body: %s", wellKnownURL, err, string(bodyBytes))
}
return &metadata, nil
}
// ServeHTTP is the main entry point for incoming requests to the middleware.
// It orchestrates the OIDC authentication flow:
// 1. Waits for initial OIDC metadata discovery to complete (with timeout).
// 2. Checks if the request path is excluded from authentication.
// 3. Checks if the request is for Server-Sent Events and bypasses if so.
// 4. Retrieves the user's session; initiates authentication if the session is invalid/missing.
// 5. Handles specific paths for OIDC callback (/callback) and logout (/logout).
// 6. Checks the user's authentication status using isUserAuthenticated (verifies token, checks expiry).
// 7. If the token is expired, handles it (initiates re-auth).
// 8. If the user is not authenticated, initiates authentication.
// 9. If the token needs proactive refresh (nearing expiry), attempts refreshToken. Handles refresh failure
// by returning 401 for API clients or initiating re-auth for browsers.
// 10. If authenticated and token is valid, performs authorization checks (allowed domain, roles/groups).
// 11. If authorized, sets user/token information in request headers (X-Forwarded-User, X-Auth-Request-*)
// and adds security headers (X-Frame-Options, etc.) to the response.
// 12. Forwards the request to the next handler in the chain.
// It orchestrates the OIDC authentication flow.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// --- Initialization Check ---
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
if t.issuerURL == "" { // Check if initialization actually succeeded
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
return
}
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
t.logger.Debug("Request cancelled while waiting for OIDC initialization")
http.Error(rw, "Request cancelled", http.StatusRequestTimeout) // 408 might be more appropriate
return
case <-time.After(30 * time.Second):
case <-time.After(30 * time.Second): // Timeout for initialization
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
// --- Excluded Paths & SSE Check ---
if t.determineExcludedURL(req.URL.Path) {
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
t.next.ServeHTTP(rw, req)
return
}
// Check if the request expects Server-Sent Events
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
@@ -633,80 +629,119 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Get session
// --- Session Retrieval ---
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
// Obtain a new session and clear any residual session cookies
session, _ = t.sessionManager.GetSession(req)
session.Clear(req, rw)
// Build redirect URL
// Log the specific session error
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
// Attempt to get a new session to store CSRF etc.
session, _ = t.sessionManager.GetSession(req) // Ignore error here, proceed with new session
if session != nil {
// Pass rw to ensure expiring cookies are sent if possible
if clearErr := session.Clear(req, rw); clearErr != nil {
t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr)
}
} else {
// If even getting a new session fails, something is very wrong
t.logger.Error("Critical session error: Failed to get even a new session.")
http.Error(rw, "Critical session error", http.StatusInternalServerError)
return
}
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Initiate authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Build redirect URL
// --- URL Handling (Callback, Logout) ---
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
redirectURL := buildFullURL(scheme, host, t.redirURLPath) // Used for callback and re-auth
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
}
// Check authentication status
// --- Authentication & Refresh Logic ---
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
// handleExpiredToken clears the session and initiates auth
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
if !authenticated {
// Original logic: Always initiate authentication if not authenticated
t.logger.Debug("User not authenticated, initiating OIDC flow")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// If authenticated and token doesn't need proactive refresh, proceed directly
if authenticated && !needsRefresh {
t.logger.Debug("User authenticated and token valid, proceeding to process authorized request")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
// --- Attempt Refresh if Needed or Possible ---
// Conditions to attempt refresh:
// 1. Token needs proactive refresh (authenticated=true, needsRefresh=true)
// 2. Token is invalid/expired but a refresh token exists (authenticated=false, needsRefresh=true)
refreshTokenPresent := session.GetRefreshToken() != ""
shouldAttemptRefresh := needsRefresh && refreshTokenPresent
if shouldAttemptRefresh {
if needsRefresh && authenticated {
t.logger.Debug("Session token needs proactive refresh, attempting refresh")
} else if needsRefresh && !authenticated {
t.logger.Debug("Access token invalid/expired, but refresh token found. Attempting refresh.")
}
refreshed := t.refreshToken(rw, req, session)
if refreshed {
// Refresh succeeded, proceed to authorization checks
t.logger.Debug("Token refresh successful, proceeding to process authorized request")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
// Refresh failed
t.logger.Infof("Token refresh failed (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
// Handle refresh failure (401 for API, re-auth for browser)
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
} else {
t.logger.Debug("Client does not prefer JSON, handling refresh failure by initiating re-auth")
// Use defaultInitiateAuthentication which clears the session properly
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
return // Stop processing
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.logger.Infof("Token refresh failed") // Changed from Warn to Infof
// Check if the client prefers JSON (likely an API call)
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
} else {
// Client likely a browser, initiate full re-authentication
t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)")
t.handleExpiredToken(rw, req, session, redirectURL)
}
return // Stop processing
}
}
// --- Initiate Full Authentication ---
// If we reach here, it means:
// - User is not authenticated (!authenticated)
// - AND EITHER token doesn't need refresh (!needsRefresh, e.g., first visit)
// - OR refresh token is missing (!refreshTokenPresent)
// - OR refresh was attempted but failed (handled above)
t.logger.Debugf("Initiating full OIDC authentication flow (authenticated=%v, needsRefresh=%v, refreshTokenPresent=%v)", authenticated, needsRefresh, refreshTokenPresent)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// Process authenticated request
// processAuthorizedRequest handles the final steps for an authenticated and authorized request.
// It performs domain/role/group checks, sets headers, and forwards the request.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Debug("No email found in session")
t.logger.Error("CRITICAL: No email found in session during final processing, initiating re-auth")
// This case should ideally not happen if checks are done correctly before calling this,
// but as a safeguard, initiate re-authentication.
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -721,6 +756,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Continue without group/role headers if extraction fails
} else {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
@@ -779,6 +815,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// Process the request
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
t.next.ServeHTTP(rw, req)
}
@@ -794,19 +831,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - session: The user's session data containing the expired token information.
// - redirectURL: The callback URL to be used in the new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
// Clear authentication data but preserve CSRF state if possible (though Clear might remove it)
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
// Save the cleared session state (this sends expired cookies)
// Pass rw to ensure expiring cookies are sent
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
t.logger.Errorf("Failed to save cleared session during expired token handling: %v", err)
// Still attempt to initiate authentication, but log the error
}
// Initiate a new authentication flow
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -834,8 +873,8 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
t.logger.Errorf("Session error during callback: %v", err)
http.Error(rw, "Session error during callback", http.StatusInternalServerError)
return
}
@@ -847,7 +886,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error") // Use error code if description is empty
}
t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
@@ -862,13 +901,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest)
t.logger.Error("CSRF token missing in session during callback")
t.sendErrorResponse(rw, req, "CSRF token missing in session", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
t.logger.Error("State parameter does not match CSRF token in session during callback")
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
@@ -886,22 +925,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
t.logger.Errorf("Failed to exchange code for token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
// Verify tokens and claims
// Use the exported VerifyToken method now that handleCallback is in main.go
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
t.logger.Errorf("Failed to verify id_token during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.logger.Errorf("Failed to extract claims during callback: %v", err)
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
@@ -909,51 +947,68 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
t.logger.Error("Nonce claim missing in id_token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
t.logger.Error("Nonce not found in session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
t.logger.Error("Nonce claim does not match session nonce during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
// Validate user's email domain
// Use the unexported isAllowedDomain method now that handleCallback is in main.go
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
if email == "" {
t.logger.Errorf("Email claim missing or empty in token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
}
if !t.isAllowedDomain(email) {
t.logger.Errorf("Disallowed email domain during callback: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
return
}
// Update session with authentication data
session.SetAuthenticated(true)
// Regenerate session ID upon successful authentication
if err := session.SetAuthenticated(true); err != nil {
t.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
http.Error(rw, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Clear CSRF, Nonce, CodeVerifier after use
session.SetCSRF("")
session.SetNonce("")
session.SetCodeVerifier("")
// Redirect to original path or root
// Retrieve original path *before* saving, as save might clear it if Clear was called concurrently
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
session.SetIncomingPath("") // Clear incoming path after retrieving it
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after callback: %v", err)
http.Error(rw, "Failed to save session after callback", http.StatusInternalServerError)
return
}
// Redirect to original path or root
t.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -972,7 +1027,7 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
return true
}
}
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
// t.logger.Debugf("URL is not excluded - got %s", currentRequest) // Too verbose for every request
return false
}
@@ -1024,32 +1079,58 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
// - expired (bool): True if the session is unauthenticated, the token is missing, or the token verification failed for reasons other than nearing/actual expiration (e.g., invalid signature, invalid claims).
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
t.logger.Debug("User is not authenticated according to session flag")
// Check if there's still a refresh token - if so, refresh might be possible
if session.GetRefreshToken() != "" {
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated, NeedsRefresh=true (to attempt recovery), Expired=false
}
return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth)
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
return false, false, true // Session is invalid, consider it expired
t.logger.Debug("Authenticated flag set, but no access token found in session")
// If authenticated flag is true but token is missing, treat as expired/invalid session state
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Authenticated flag set, access token missing, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (no access token), NeedsRefresh=true, Expired=false
}
return false, false, true // No access or refresh token, treat as expired
}
// Verify the token structure and signature first
jwt, err := parseJWT(accessToken)
if err != nil {
t.logger.Errorf("Failed to parse JWT during auth check: %v", err)
return false, false, true // Invalid format, treat as expired/invalid
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token parsing failed, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Invalid format, no refresh token, treat as expired/invalid
}
if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil {
// Check if the error is specifically about expiration
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh")
t.logger.Debugf("Access token signature/claims valid but token expired, needs refresh")
// Token is expired but otherwise valid, signal for refresh
return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible)
// Return authenticated=false because the current token is unusable
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it)
}
return false, false, true // Expired access token, no refresh token, treat as expired
}
// Other verification error (signature, issuer, audience etc.)
t.logger.Errorf("Token verification failed (non-expiration): %v", err)
return false, false, true // Token is invalid for other reasons
t.logger.Errorf("Access token verification failed (non-expiration): %v", err)
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token verification failed, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session
}
// Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early
@@ -1057,8 +1138,13 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
// Check for refresh token before declaring fully expired
if session.GetRefreshToken() != "" {
t.logger.Debug("Access token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad access token), NeedsRefresh=true, Expired=false
}
return false, false, true // Treat as invalid if 'exp' is missing and no refresh token
}
expTime := int64(expClaim)
@@ -1071,12 +1157,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod)
return true, true, false // Needs proactive refresh
t.logger.Debugf("Access token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod)
// Token is still valid, but we should refresh it soon
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false
}
// If no refresh token, we can't proactively refresh, treat as normal valid token for now
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
return true, false, false
}
// Token is valid, not expired, and not nearing expiration
return true, false, false
return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false
}
// defaultInitiateAuthentication handles the process of starting an OIDC authentication flow.
@@ -1092,10 +1185,12 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// - session: The user's SessionData object (potentially new or cleared).
// - redirectURL: The pre-calculated callback URL (redirect_uri) for this middleware instance.
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
// Generate CSRF token and nonce
csrfToken := uuid.NewString()
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
@@ -1106,37 +1201,41 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
var err error
codeVerifier, err = generateCodeVerifier()
if err != nil {
t.logger.Errorf("Failed to generate code verifier: %v", err)
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
return
}
// Derive code challenge from verifier
codeChallenge = deriveCodeChallenge(codeVerifier)
t.logger.Debugf("PKCE enabled, generated code challenge")
}
// Clear any existing session data to avoid stale state causing redirect loops
session.Clear(req, rw)
// Pass the response writer to ensure expiring cookies are sent
if err := session.Clear(req, rw); err != nil {
// Log the error but continue, as clearing is best-effort before re-auth
t.logger.Errorf("Error clearing session before initiating authentication: %v", err)
}
// Set new session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
// Only set code verifier if PKCE is enabled
if t.enablePKCE {
session.SetCodeVerifier(codeVerifier)
}
// Store the original path the user was trying to access
session.SetIncomingPath(req.URL.RequestURI())
t.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
// Save the session
// Save the session (to store CSRF, Nonce, etc.)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
t.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build and redirect to authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
t.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -1181,10 +1280,37 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
params.Set("code_challenge_method", "S256")
}
if len(t.scopes) > 0 {
params.Set("scope", strings.Join(t.scopes, " "))
// Handle scopes - ensure offline_access is included for refresh tokens
scopes := make([]string, len(t.scopes))
copy(scopes, t.scopes)
// Check if we're dealing with a Google OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") || strings.Contains(t.issuerURL, "accounts.google.com")
// Add offline_access scope if it's missing
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
}
if len(scopes) > 0 {
params.Set("scope", strings.Join(scopes, " "))
}
// Add prompt=consent for Google to ensure refresh token is issued
if isGoogleProvider {
params.Set("prompt", "consent")
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
}
// Use buildURLWithParams which handles potential relative authURL from metadata
return t.buildURLWithParams(t.authURL, params)
}
@@ -1201,17 +1327,30 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) string {
// Ensure URL is absolute
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
// Extract issuer base URL
issuerURL, err := url.Parse(t.issuerURL)
// Attempt to resolve relative URL against issuer URL
issuerURLParsed, err := url.Parse(t.issuerURL)
if err == nil {
return fmt.Sprintf("%s://%s%s?%s",
issuerURL.Scheme,
issuerURL.Host,
baseURL,
params.Encode())
baseURLParsed, err := url.Parse(baseURL)
if err == nil {
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
resolvedURL.RawQuery = params.Encode()
return resolvedURL.String()
}
}
// Fallback if parsing fails - append params to potentially relative path
t.logger.Errorf("Could not parse issuerURL or baseURL to resolve relative URL. BaseURL: %s, IssuerURL: %s", baseURL, t.issuerURL)
return baseURL + "?" + params.Encode()
}
return baseURL + "?" + params.Encode()
// If baseURL is already absolute
u, err := url.Parse(baseURL)
if err != nil {
t.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
// Fallback: append params directly
return baseURL + "?" + params.Encode()
}
u.RawQuery = params.Encode()
return u.String()
}
// startTokenCleanup starts background goroutines for periodically cleaning up
@@ -1246,6 +1385,7 @@ func (t *TraefikOidc) RevokeToken(token string) {
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
t.tokenBlacklist.Set(token, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token (added to blacklist)")
}
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
@@ -1260,7 +1400,10 @@ func (t *TraefikOidc) RevokeToken(token string) {
// - nil if the revocation request is successful (provider returns 200 OK).
// - An error if the request fails or the provider returns a non-OK status.
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
if t.revocationURL == "" {
return fmt.Errorf("token revocation endpoint is not configured or discovered")
}
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
data := url.Values{
"token": {token},
@@ -1277,6 +1420,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
// Set headers
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json") // Prefer JSON response if available
// Send the request
resp, err := t.httpClient.Do(req)
@@ -1288,18 +1432,20 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
// Check the response
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("token revocation failed with status %d: %s", resp.StatusCode, string(body))
// Log the failure details
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
}
t.logger.Debugf("Token successfully revoked")
t.logger.Debugf("Token successfully revoked with provider")
return nil
}
// refreshToken attempts to use the refresh token stored in the session to obtain a new set of tokens.
// It acquires a mutex associated with the session to prevent concurrent refresh attempts for the same session.
// It retrieves the refresh token, calls the TokenExchanger's GetNewTokenWithRefreshToken method,
// verifies the newly obtained ID token using verifyToken, updates the session with the new tokens,
// and saves the session.
// verifies the newly obtained ID token using verifyToken, performs a concurrency check,
// updates the session with the new tokens if the check passes, and saves the session.
//
// Parameters:
// - rw: The HTTP response writer (needed for saving the updated session).
@@ -1308,45 +1454,134 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
//
// Returns:
// - true if the token refresh was successful and the session was updated.
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification, or saving the session failed.
// - false if no refresh token was found, the refresh exchange failed, the new token failed verification,
// a concurrency conflict was detected, or saving the session failed.
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
// Lock the mutex specific to this session instance before attempting refresh
session.refreshMutex.Lock()
defer session.refreshMutex.Unlock()
t.logger.Debug("Attempting to refresh token (mutex acquired)")
refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
if refreshToken == "" {
t.logger.Debug("No refresh token found in session (inside lock)")
initialRefreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
if initialRefreshToken == "" {
t.logger.Errorf("refreshToken failed: No refresh token found in session (after acquiring lock)")
return false
}
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
// Detect if we're using Google's OIDC provider
isGoogleProvider := strings.Contains(t.issuerURL, "google") || strings.Contains(t.issuerURL, "accounts.google.com")
if isGoogleProvider {
t.logger.Debug("Google OIDC provider detected for token refresh operation")
}
// Log the attempt with a truncated token for security
tokenPrefix := initialRefreshToken
if len(initialRefreshToken) > 10 {
tokenPrefix = initialRefreshToken[:10]
}
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
// Attempt to refresh the token
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
if err != nil {
// Log the error, potentially clear the invalid refresh token?
t.logger.Errorf("Failed to refresh token using refresh token: %v", err)
// Consider clearing the refresh token from the session here if the error indicates it's invalid
// session.SetRefreshToken("") // Example: Clear potentially invalid token
// session.Save(req, rw) // Need to handle potential save error
// Log detailed error information
t.logger.Errorf("refreshToken failed: Error from token refresh operation: %v", err)
// Check for specific error patterns
errMsg := err.Error()
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
t.logger.Errorf("Refresh token appears to be expired or revoked: %v", err)
// Don't keep trying with an invalid refresh token
session.SetRefreshToken("")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to remove invalid refresh token from session: %v", err)
}
} else if strings.Contains(errMsg, "invalid_client") {
t.logger.Errorf("Client credentials rejected: %v - check client_id and client_secret configuration", err)
} else if isGoogleProvider && strings.Contains(errMsg, "invalid_request") {
t.logger.Errorf("Google OIDC provider error: %v - check scope configuration includes 'offline_access' and prompt=consent is used during authentication", err)
}
return false
}
// Verify the new access token
// Handle potentially missing tokens in the response
if newToken.IDToken == "" {
t.logger.Errorf("refreshToken failed: Provider did not return a new ID token")
return false
}
// Verify the new access token (ID token)
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
truncatedNewToken := newToken.IDToken
if len(newToken.IDToken) > 10 {
truncatedNewToken = newToken.IDToken[:10]
}
t.logger.Errorf("refreshToken failed: Failed to verify newly obtained ID token starting with %s...: %v", truncatedNewToken, err)
return false
}
// Update session with new tokens
// --- Concurrency Check ---
// Before saving the new token, check if the session state (specifically the refresh token)
// has been modified concurrently (e.g., by a logout or another auth initiation).
currentRefreshToken := session.GetRefreshToken() // Get token again *after* the potentially long exchange
if initialRefreshToken != currentRefreshToken {
// Use Infof as Warnf doesn't exist
t.logger.Infof("refreshToken aborted: Session refresh token changed concurrently during refresh attempt.")
// Do not save the new tokens, as the session state is likely invalid/cleared.
return false // Indicate refresh failure due to concurrency conflict
}
// --- End Concurrency Check ---
// Update session with new tokens ONLY if the concurrency check passed
t.logger.Debugf("Concurrency check passed. Updating session with new tokens.")
// Extract email from the new token and update session
claims, err := t.extractClaimsFunc(newToken.IDToken)
if err != nil {
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
return false // Cannot proceed without claims
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
return false // Cannot proceed without email
}
session.SetEmail(email) // Update email in session
// Get token expiry information for logging
var expiryTime time.Time
if expClaim, ok := claims["exp"].(float64); ok {
expiryTime = time.Unix(int64(expClaim), 0)
t.logger.Debugf("New token expires at: %v (in %v)", expiryTime, time.Until(expiryTime))
}
// Set the new access token
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Handle the refresh token
if newToken.RefreshToken != "" {
t.logger.Debug("Received new refresh token from provider")
session.SetRefreshToken(newToken.RefreshToken)
} else {
// If no new refresh token is returned, keep the existing one
t.logger.Debug("Provider did not return a new refresh token, keeping the existing one")
session.SetRefreshToken(initialRefreshToken)
}
// Ensure authenticated flag is set
if err := session.SetAuthenticated(true); err != nil {
t.logger.Errorf("refreshToken warning: Failed to set authenticated flag: %v", err)
// Continue anyway since we have valid tokens
}
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh: %v", err)
return false
}
t.logger.Debugf("Token refresh successful and session saved")
return true
}
@@ -1367,6 +1602,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
parts := strings.Split(email, "@")
if len(parts) != 2 {
t.logger.Errorf("Invalid email format encountered: %s", email)
return false // Invalid email format
}
@@ -1400,12 +1636,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
// Strictly expect an array
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
} else {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
} else {
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
}
}
}
}
@@ -1414,12 +1654,16 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
// Strictly expect an array
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
} else {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
} else {
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
}
}
}
}
@@ -1493,8 +1737,8 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
rw.WriteHeader(code)
// Use a simple error structure
json.NewEncoder(rw).Encode(map[string]interface{}{
"error": http.StatusText(code),
"error_description": message,
"error": http.StatusText(code), // Use standard text for the code
"error_description": message, // Provide specific detail here
"status_code": code,
})
return
@@ -1504,17 +1748,9 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
// Determine the return URL (mostly relevant for HTML)
returnURL := "/" // Default to root
session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL
if err == nil {
incomingPath := session.GetIncomingPath()
// Use incoming path if it's valid and not one of the special OIDC paths
if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath {
returnURL = incomingPath
}
} else {
t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err)
}
returnURL := "/" // Default to root
// No need to get session here, as we are already in an error path
// where session might be invalid or unavailable.
// Basic HTML structure for the error page
htmlBody := fmt.Sprintf(`
@@ -1537,7 +1773,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
<p><a href="%s">Return to application</a></p>
</div>
</body>
</html>`, message, returnURL)
</html>`, message, returnURL) // Use default returnURL
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(code)
+71 -17
View File
@@ -385,9 +385,52 @@ func TestServeHTTP(t *testing.T) {
expectedBody: "OK",
},
{
name: "Unauthenticated request to protected URL",
requestPath: "/protected",
expectedStatus: http.StatusFound, // Expect redirect to OIDC
name: "Unauthenticated request (no refresh token) to protected URL",
requestPath: "/protected",
setupSession: func(session *SessionData) {
// Ensure no tokens are set
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token-for-unauth-test" {
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
}
// Simulate successful refresh
newToken := createNewValidToken() // Use helper from TestServeHTTP
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
}
},
expectedStatus: http.StatusOK, // Expect OK after successful refresh
expectedBody: "OK",
},
{
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(false) // Not authenticated
session.SetAccessToken("") // No access token
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
// Simulate failed refresh
return nil, fmt.Errorf("mock error: refresh token invalid")
}
},
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
},
{
name: "Authenticated request to protected URL (Valid Token)",
@@ -407,11 +450,15 @@ func TestServeHTTP(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: "OK",
},
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
{
name: "Authenticated request with expired token and successful refresh",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Still marked authenticated initially
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Set expired token
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
@@ -445,16 +492,19 @@ func TestServeHTTP(t *testing.T) {
t.Fatalf("Failed to get session after request: %v", err)
}
// Assert new tokens are in the session
// Direct comparison with createNewValidToken() is flawed as it generates a new token each time.
// Instead, check if the token was updated (not empty) and verify the refresh token.
if session.GetAccessToken() == "" {
t.Errorf("Expected access token to be updated in session, but it was empty")
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
}
if session.GetRefreshToken() != "new-refresh-token" {
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
}
// Also check authenticated flag is now true
if !session.GetAuthenticated() {
t.Errorf("Expected session to be marked authenticated after successful refresh")
}
},
},
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
{
name: "Logout URL",
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
@@ -477,10 +527,10 @@ func TestServeHTTP(t *testing.T) {
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
session.SetAccessToken(createExpiredToken()) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -491,17 +541,18 @@ func TestServeHTTP(t *testing.T) {
requestHeaders: map[string]string{
"Accept": "application/json",
},
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
expectedBody: `{"error":"unauthorized","message":"Token refresh failed"}`,
},
// This test case remains valid as the logic should still redirect browser clients on refresh failure
{
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken())
session.SetRefreshToken("valid-refresh-token")
session.SetAccessToken(createExpiredToken()) // Expired access token
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -512,8 +563,9 @@ func TestServeHTTP(t *testing.T) {
requestHeaders: map[string]string{
"Accept": "text/html", // Browser client
},
expectedStatus: http.StatusFound, // Expect redirect for browser client
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
},
// This test case remains valid as proactive refresh should still be attempted
{
name: "Authenticated request with token nearing expiry (needs refresh)",
requestPath: "/protected",
@@ -529,7 +581,7 @@ func TestServeHTTP(t *testing.T) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry")
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
@@ -544,6 +596,7 @@ func TestServeHTTP(t *testing.T) {
expectedStatus: http.StatusOK, // Expect success after proactive refresh
expectedBody: "OK",
},
// This test case remains valid as no refresh should be attempted
{
name: "Authenticated request with token valid (outside grace period)",
requestPath: "/protected",
@@ -1531,6 +1584,7 @@ func TestRevokeToken(t *testing.T) {
tOidc := &TraefikOidc{
tokenBlacklist: NewCache(), // Use generic cache for blacklist
tokenCache: NewTokenCache(),
logger: NewLogger("info"), // Initialize the logger
}
// Cache the token