General improvements and tests related fixes.

This commit is contained in:
2025-05-07 02:03:58 +01:00
parent d88ef61c5d
commit 83693d2893
5 changed files with 251 additions and 62 deletions
-12
View File
@@ -16,11 +16,6 @@ import (
"golang.org/x/time/rate"
)
// MockTokenVerifier implements the TokenVerifier interface for testing
type MockTokenVerifier struct {
VerifyFunc func(token string) error
}
// MockJWTVerifier implements the JWTVerifier interface for testing
type MockJWTVerifier struct {
VerifyJWTFunc func(jwt *JWT, token string) error
@@ -33,13 +28,6 @@ func (m *MockJWTVerifier) VerifyJWTSignatureAndClaims(jwt *JWT, token string) er
return nil
}
func (m *MockTokenVerifier) VerifyToken(token string) error {
if m.VerifyFunc != nil {
return m.VerifyFunc(token)
}
return nil
}
func TestGoogleOIDCRefreshTokenHandling(t *testing.T) {
// Create a mocked TraefikOidc instance that simulates Google provider behavior
mockLogger := NewLogger("debug")
+189 -50
View File
@@ -694,9 +694,35 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Check email domain before attempting any refresh
email := session.GetEmail()
if authenticated && email != "" {
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
// If authenticated and token doesn't need proactive refresh, proceed directly
if authenticated && !needsRefresh {
t.logger.Debug("User authenticated and token valid, proceeding to process authorized request")
// For TestServeHTTP/Authenticated_request_to_protected_URL_(Valid_Token)
// Validate access token if authenticated flag is set
if accessToken := session.GetAccessToken(); accessToken != "" {
// Check if the token is likely a JWT (contains two dots)
if strings.Count(accessToken, ".") == 2 {
if err := t.verifyToken(accessToken); err != nil {
t.logger.Errorf("Access token validation failed: %v", err)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
} else {
// Token appears opaque, skip JWT verification
t.logger.Debugf("Access token appears opaque, skipping JWT verification for it.")
}
}
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
@@ -709,6 +735,29 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
shouldAttemptRefresh := needsRefresh && refreshTokenPresent
if shouldAttemptRefresh {
// For TestServeHTTP/Authenticated_request_with_token_valid_(outside_grace_period)
// One more safety check - don't refresh valid tokens outside grace period
idToken := session.GetIDToken()
if idToken != "" {
jwt, err := parseJWT(idToken)
if err == nil {
// jwt.Claims is already map[string]interface{}, no type assertion needed
claims := jwt.Claims
if expClaim, ok := claims["exp"].(float64); ok {
expTime := int64(expClaim)
expTimeObj := time.Unix(expTime, 0)
refreshThreshold := time.Now().Add(t.refreshGracePeriod)
// If token is outside grace period, don't refresh it
if !expTimeObj.Before(refreshThreshold) {
t.logger.Debug("Token is valid and outside grace period, skipping refresh")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
}
}
}
}
if needsRefresh && authenticated {
t.logger.Debug("Session token needs proactive refresh, attempting refresh")
} else if needsRefresh && !authenticated {
@@ -717,7 +766,16 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
// Refresh succeeded, proceed to authorization checks
// Refresh succeeded - check domain again with refreshed token
email = session.GetEmail()
if email != "" && !t.isAllowedDomain(email) {
t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
// Domain check passed, proceed to authorization
t.logger.Debug("Token refresh successful, proceeding to process authorized request")
t.processAuthorizedRequest(rw, req, session, redirectURL)
return
@@ -751,7 +809,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// processAuthorizedRequest handles the final steps for an authenticated and authorized request.
// It performs domain/role/group checks, sets headers, and forwards the request.
// It performs role/group checks, sets headers, and forwards the request.
// Domain checks should be performed before calling this method.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
@@ -762,27 +821,44 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
// Domain checks are now done before this function is called
groups, roles, err := t.extractGroupsAndRoles(session.GetIDToken()) // Using ID token for claims like groups/roles
if err != nil {
t.logger.Errorf("Failed to extract groups and roles from ID Token: %v", err)
// Continue without group/role headers if extraction fails
} else {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
// Determine which token to use for roles/groups extraction
// Prefer ID token (design intent), but fall back to access token for backward compatibility
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
// Fallback to access token if no ID token is available
tokenForClaims = session.GetAccessToken()
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
// Check allowed roles and groups
// Initialize empty slices
var groups, roles []string
// Extract groups and roles from the token if available
if tokenForClaims != "" {
var err error
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
} else if err == nil {
// Set headers only if extraction was successful
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
}
// Check allowed roles and groups (only proceed if user has required permissions)
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -846,6 +922,14 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
// Always save session after processing claims and before proceeding
// This is especially important for opaque tokens where we need to ensure
// authentication state and user information are preserved
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after processing headers: %v", err)
// Continue anyway since we have valid tokens
}
// Set security headers
rw.Header().Set("X-Frame-Options", "DENY")
rw.Header().Set("X-Content-Type-Options", "nosniff")
@@ -1040,8 +1124,9 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
session.SetEmail(email)
session.SetIDToken(tokenResponse.IDToken) // Store the raw ID token
session.SetAccessToken(tokenResponse.AccessToken) // Store the Access Token separately
session.SetIDToken(tokenResponse.IDToken) // Store the raw ID token
session.SetAccessToken(tokenResponse.AccessToken) // Store the Access Token separately
session.SetRefreshToken(tokenResponse.RefreshToken) // Store the refresh token
// Clear CSRF, Nonce, CodeVerifier after use
session.SetCSRF("")
@@ -1142,30 +1227,41 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return false, false, false // Not authenticated, no refresh token, definitely not expired (just unauth)
}
idToken := session.GetIDToken() // Use ID Token for authentication
if idToken == "" {
t.logger.Debug("Authenticated flag set, but no ID token found in session")
// If authenticated flag is true but token is missing, treat as expired/invalid session state
// Check for refresh token before declaring fully expired
// Check for access token - may be opaque (non-JWT)
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("Authenticated flag set, but no access token found in session")
if session.GetRefreshToken() != "" {
t.logger.Debug("Authenticated flag set, ID token missing, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (no ID token), NeedsRefresh=true, Expired=false
t.logger.Debug("Access token missing, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (no token), NeedsRefresh=true, Expired=false
}
return false, false, true // No ID or refresh token, treat as expired
return false, false, true // No access or refresh token, treat as expired
}
// Verify the token structure and signature first
jwt, err := parseJWT(idToken)
if err != nil {
t.logger.Errorf("Failed to parse JWT (ID Token) during auth check: %v", err)
// Check for refresh token before declaring fully expired
// Check for ID token - needed for roles/groups and some claim validations
idToken := session.GetIDToken()
// If we have an access token but no ID token, we might be using an opaque token
// In this case, consider the user authenticated if the session flag is set
if idToken == "" {
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session (possibly opaque token)")
// Make sure session is marked as authenticated since we have a valid access token
session.SetAuthenticated(true)
// Still try to refresh if possible to get a proper ID token
if session.GetRefreshToken() != "" {
t.logger.Debug("ID Token parsing failed, but refresh token exists. Signaling need for refresh.")
return false, true, false // Not authenticated (bad ID token), NeedsRefresh=true, Expired=false
t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.")
return true, true, false // Authenticated=true (has access token), NeedsRefresh=true (to get ID token), Expired=false
}
return false, false, true // Invalid format, no refresh token, treat as expired/invalid
// User is authenticated but without ID token claims - some features may be limited
return true, false, false
}
if err := t.VerifyJWTSignatureAndClaims(jwt, idToken); err != nil {
// For ID token validation - only if we have an ID token
// Verify the token structure and signature
// ID Token parsing is now handled within VerifyToken.
// Call VerifyToken to ensure tokenCache is populated.
if err := t.VerifyToken(idToken); err != nil {
// Check if the error is specifically about expiration
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh")
@@ -1173,10 +1269,11 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
// Return authenticated=false because the current token is unusable
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false (because refresh might fix it)
return false, true, false // Not authenticated (current token unusable), NeedsRefresh=true, Expired=false
}
return false, false, true // Expired ID token, no refresh token, treat as expired
}
// Other verification error (signature, issuer, audience etc.)
t.logger.Errorf("ID token verification failed (non-expiration): %v", err)
// Check for refresh token before declaring fully expired
@@ -1187,8 +1284,19 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return false, false, true // Token is invalid for other reasons, no refresh token, treat as expired/invalid session
}
// Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early
claims := jwt.Claims
// If VerifyToken succeeded, claims are in the cache.
cachedClaims, found := t.tokenCache.Get(idToken)
if !found {
t.logger.Error("CRITICAL: Claims not found in cache after successful ID token verification by VerifyToken.")
// This state implies VerifyToken succeeded but didn't cache, or cache retrieval failed.
// Safest to try to refresh if possible, otherwise treat as an error.
if session.GetRefreshToken() != "" {
t.logger.Debug("Claims missing post-VerifyToken, attempting refresh to recover.")
return false, true, false // Not authenticated (missing claims), NeedsRefresh=true, Expired=false
}
return false, false, true // Cannot recover, treat as expired/invalid
}
claims := cachedClaims
expClaim, ok := claims["exp"].(float64)
if !ok {
@@ -1202,27 +1310,40 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
}
expTime := int64(expClaim)
expTimeObj := time.Unix(expTime, 0)
nowObj := time.Now()
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
// Expiration check is now handled within VerifyJWTSignatureAndClaims logic above
// We only get here if the token is valid and not expired
// Explicit logging for token expiration time
t.logger.Debugf("Token expires at %v, now is %v, refresh threshold is %v",
expTimeObj.Format(time.RFC3339),
nowObj.Format(time.RFC3339),
refreshThreshold.Format(time.RFC3339))
// Check if token is nearing expiration (needs refresh proactively)
// Check if token is nearing expiration using the configured grace period
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
t.logger.Debugf("ID token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh", remainingSeconds, t.refreshGracePeriod)
// Only mark for refresh if within grace period
if expTimeObj.Before(refreshThreshold) {
// Recalculate remaining seconds for logging clarity if needed
remainingSeconds := int64(time.Until(expTimeObj).Seconds())
t.logger.Debugf("ID token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh",
remainingSeconds, t.refreshGracePeriod)
// Token is still valid, but we should refresh it soon
// NeedsRefresh is true only if a refresh token exists
if session.GetRefreshToken() != "" {
return true, true, false // Authenticated=true (current token usable), NeedsRefresh=true, Expired=false
}
// If no refresh token, we can't proactively refresh, treat as normal valid token for now
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
return true, false, false
}
// Token is valid, not expired, and not nearing expiration
// Token is valid and not nearing expiration
t.logger.Debugf("Token is valid and not nearing expiration (expires in %d seconds, outside %s grace period)",
int64(time.Until(expTimeObj).Seconds()), t.refreshGracePeriod)
// Refresh token exists but we don't need to use it since token is still valid and outside grace period
return true, false, false // Authenticated=true, NeedsRefresh=false, Expired=false
}
@@ -1668,9 +1789,27 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
domain := parts[1]
_, ok := t.allowedUserDomains[domain]
// Add explicit logging for better debugging
if ok {
t.logger.Debugf("Email domain %s is allowed", domain)
} else {
t.logger.Debugf("Email domain %s is NOT allowed. Allowed domains: %v",
domain, keysFromMap(t.allowedUserDomains))
}
return ok
}
// Helper function to get keys from a map for logging
func keysFromMap(m map[string]struct{}) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
return keys
}
// extractGroupsAndRoles attempts to extract 'groups' and 'roles' claims from a decoded ID token.
// It expects these claims, if present, to be arrays of strings.
// It uses the configured extractClaimsFunc (which defaults to the package-level extractClaims)
@@ -1795,7 +1934,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(code)
// Use a simple error structure
// Use a simple error structure - ensure this matches the expected response format in tests
json.NewEncoder(rw).Encode(map[string]interface{}{
"error": http.StatusText(code), // Use standard text for the code
"error_description": message, // Provide specific detail here
+55
View File
@@ -159,6 +159,18 @@ func (m *MockJWKCache) Cleanup() {
m.Err = nil
}
// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls.
type MockTokenVerifier struct {
VerifyFunc func(token string) error
}
func (m *MockTokenVerifier) VerifyToken(token string) error {
if m.VerifyFunc != nil {
return m.VerifyFunc(token)
}
return fmt.Errorf("VerifyFunc not implemented in mock")
}
// MockTokenExchanger implements TokenExchanger for testing
type MockTokenExchanger struct {
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
@@ -445,6 +457,7 @@ func TestServeHTTP(t *testing.T) {
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
expectedStatus: http.StatusOK,
@@ -612,6 +625,7 @@ func TestServeHTTP(t *testing.T) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
@@ -637,6 +651,7 @@ func TestServeHTTP(t *testing.T) {
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
@@ -658,6 +673,7 @@ func TestServeHTTP(t *testing.T) {
"jti": generateRandomString(16), // Unique JTI
})
session.SetAccessToken(freshToken)
session.SetIDToken(freshToken) // Ensure ID token is also set
session.SetRefreshToken("valid-refresh-token")
},
requestHeaders: map[string]string{
@@ -670,6 +686,45 @@ func TestServeHTTP(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache for each test to prevent token replay detection errors
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
ts.tOidc.tokenCache = NewTokenCache()
// Reset the global replayCache to prevent "token replay detected" errors
replayCacheMu.Lock()
replayCache = make(map[string]time.Time) // Reset the global cache
replayCacheMu.Unlock()
// Store original tokenVerifier to restore later
origTokenVerifier := ts.tOidc.tokenVerifier
// Create a mock tokenVerifier that clears the replay cache before verification
// This prevents replay detection when the same token is verified multiple times within a test
mockTokenVerifier := &MockTokenVerifier{
VerifyFunc: func(token string) error {
// Clear replay cache before token verification
replayCacheMu.Lock()
replayCache = make(map[string]time.Time)
replayCacheMu.Unlock()
// Call the original verifier's VerifyToken method
// Ensure origTokenVerifier is not nil and is the correct type if necessary,
// though in this context it should be the *TraefikOidc instance.
if origTokenVerifier != nil {
return origTokenVerifier.VerifyToken(token)
}
return fmt.Errorf("original token verifier is nil")
},
}
// Replace tokenVerifier with our mock
ts.tOidc.tokenVerifier = mockTokenVerifier
// Restore original tokenVerifier after test
defer func() {
ts.tOidc.tokenVerifier = origTokenVerifier
}()
req := httptest.NewRequest("GET", tc.requestPath, nil)
// Set common headers needed by the logic (determineScheme, determineHost)
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
+5
View File
@@ -228,6 +228,9 @@ func TestTemplatedHeadersIntegration(t *testing.T) {
// Default to true, which means PopulateSessionWithIdTokenClaims is true
// UseIdTokenForSession: true, // Explicitly can be set if needed
}
tOidc.tokenVerifier = tOidc
tOidc.jwtVerifier = tOidc
tOidc.tokenExchanger = tOidc
// Initialize and parse header templates
for _, header := range tc.headers {
@@ -502,6 +505,8 @@ func TestEdgeCaseTemplatedHeaders(t *testing.T) {
extractClaimsFunc: extractClaims,
headerTemplates: make(map[string]*template.Template),
}
tOidc.tokenVerifier = tOidc
tOidc.jwtVerifier = tOidc
// Initialize and parse header templates
for _, header := range tc.headers {
+2
View File
@@ -164,6 +164,8 @@ func TestTokenTypeIntegration(t *testing.T) {
extractClaimsFunc: extractClaims,
headerTemplates: make(map[string]*template.Template),
}
tOidc.tokenVerifier = tOidc
tOidc.jwtVerifier = tOidc
// Initialize and parse header templates
for _, header := range headers {