Multiple improvements for April 2025

* Improve refresh token handling in the background.

Resolves issue when user opens the website, allows the access token to expire, but continues browsing.
The background requests are failing with CORS errors to OIDC provider.

* fixup! Improve refresh token handling in the background.

* Abstract the token blacklisting.
This commit is contained in:
2025-04-04 18:42:41 +01:00
committed by GitHub
parent 4322407129
commit 23e019092a
9 changed files with 587 additions and 611 deletions
-110
View File
@@ -1,110 +0,0 @@
package traefikoidc
import (
"sync"
"time"
)
// TokenBlacklist manages a thread-safe list of revoked tokens with expiration.
type TokenBlacklist struct {
tokens map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new token blacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
tokens: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (b *TokenBlacklist) Add(token string, expiry time.Time) {
b.mutex.Lock()
defer b.mutex.Unlock()
// Clean up expired tokens if we're at capacity
if len(b.tokens) >= 1000 {
now := time.Now()
futureThreshold := now.Add(time.Minute)
for t, exp := range b.tokens {
if now.After(exp) || futureThreshold.After(exp) {
delete(b.tokens, t)
}
}
// If still at capacity, remove oldest token
if len(b.tokens) >= 1000 {
var oldestToken string
var oldestTime time.Time
first := true
for t, exp := range b.tokens {
if first || exp.Before(oldestTime) {
oldestToken = t
oldestTime = exp
first = false
}
}
if oldestToken != "" {
delete(b.tokens, oldestToken)
}
}
}
b.tokens[token] = expiry
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (b *TokenBlacklist) IsBlacklisted(token string) bool {
b.mutex.RLock()
defer b.mutex.RUnlock()
expiry, exists := b.tokens[token]
if !exists {
return false
}
// If token is expired, remove it and return false
if time.Now().After(expiry) {
// Switch to write lock to remove expired token
b.mutex.RUnlock()
b.mutex.Lock()
delete(b.tokens, token)
b.mutex.Unlock()
b.mutex.RLock()
return false
}
return true
}
// Cleanup removes expired tokens from the blacklist.
// Also removes tokens that will expire within the next minute to prevent edge cases.
func (b *TokenBlacklist) Cleanup() {
b.mutex.Lock()
defer b.mutex.Unlock()
now := time.Now()
futureThreshold := now.Add(time.Minute)
for token, expiry := range b.tokens {
// Remove tokens that are expired or will expire soon
if now.After(expiry) || futureThreshold.After(expiry) {
delete(b.tokens, token)
}
}
}
// Remove removes a token from the blacklist regardless of its expiration.
func (b *TokenBlacklist) Remove(token string) {
b.mutex.Lock()
defer b.mutex.Unlock()
delete(b.tokens, token)
}
// Count returns the current number of tokens in the blacklist.
func (b *TokenBlacklist) Count() int {
b.mutex.RLock()
defer b.mutex.RUnlock()
return len(b.tokens)
}
-74
View File
@@ -1,74 +0,0 @@
package traefikoidc
import (
"testing"
"time"
)
func TestTokenBlacklist_Add(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
}
func TestTokenBlacklist_IsBlacklisted(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
if !blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be blacklisted, but it was not")
}
if blacklist.IsBlacklisted("nonExistentToken") {
t.Errorf("Expected non-existent token to not be blacklisted, but it was")
}
}
func TestTokenBlacklist_Cleanup(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(-time.Hour) // Expired token
blacklist.Add(token, expiry)
blacklist.Cleanup()
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected expired token to be removed after cleanup, but it was not")
}
}
func TestTokenBlacklist_Remove(t *testing.T) {
blacklist := NewTokenBlacklist()
token := "testToken"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token, expiry)
blacklist.Remove(token)
if blacklist.IsBlacklisted(token) {
t.Errorf("Expected token to be removed, but it was not")
}
}
func TestTokenBlacklist_Count(t *testing.T) {
blacklist := NewTokenBlacklist()
token1 := "token1"
token2 := "token2"
expiry := time.Now().Add(time.Hour)
blacklist.Add(token1, expiry)
blacklist.Add(token2, expiry)
if blacklist.Count() != 2 {
t.Errorf("Expected blacklist count to be 2, but got %d", blacklist.Count())
}
}
+1 -144
View File
@@ -81,7 +81,7 @@ type TokenResponse struct {
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
// - codeVerifier: Optional PKCE code verifier for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string, codeVerifier string) (*TokenResponse, error) {
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -153,149 +153,6 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate CSRF state
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
http.Error(rw, "No code in callback", http.StatusBadRequest)
return
}
// Get the code verifier from the session for PKCE flow
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify tokens and claims
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Validate user's email domain
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
return
}
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
func extractClaims(tokenString string) (map[string]interface{}, error) {
+6 -166
View File
@@ -7,172 +7,12 @@ import (
"time"
)
func TestTokenBlacklistSizeLimit(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens up to maxSize
for i := 0; i < 1000; i++ {
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
}
// Verify size is at max
if tb.Count() != 1000 {
t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count())
}
// Add one more token, should trigger cleanup/eviction
tb.Add("newtoken", time.Now().Add(time.Hour))
// Size should still be at max
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestTokenBlacklistExpiredCleanup(t *testing.T) {
tb := NewTokenBlacklist()
// Add some expired tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("expired%d", i), time.Now().Add(-time.Hour))
}
// Add some valid tokens
for i := 0; i < 500; i++ {
tb.Add(fmt.Sprintf("valid%d", i), time.Now().Add(time.Hour))
}
// Force cleanup
tb.Cleanup()
// Only valid tokens should remain
if tb.Count() != 500 {
t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count())
}
// Verify only valid tokens remain
tb.mutex.RLock()
defer tb.mutex.RUnlock()
for token, expiry := range tb.tokens {
if time.Now().After(expiry) {
t.Errorf("Found expired token after cleanup: %s", token)
}
}
}
func TestTokenBlacklistOldestEviction(t *testing.T) {
tb := NewTokenBlacklist()
// Add tokens at capacity with different expiration times
baseTime := time.Now()
oldestToken := "oldest"
// Add oldest token first
tb.Add(oldestToken, baseTime.Add(time.Hour))
// Fill up to capacity with newer tokens
for i := 0; i < 999; i++ {
tb.Add(fmt.Sprintf("token%d", i), baseTime.Add(time.Hour*2))
}
// Add a new token that should evict the oldest
newToken := "newest"
tb.Add(newToken, baseTime.Add(time.Hour*3))
// Verify oldest token was evicted
if tb.IsBlacklisted(oldestToken) {
t.Error("Oldest token should have been evicted")
}
// Verify newest token is present
if !tb.IsBlacklisted(newToken) {
t.Error("Newest token should be present")
}
}
func TestTokenBlacklistMemoryUsage(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 10000
// Force initial GC
runtime.GC()
// Record initial memory stats
var m1, m2 runtime.MemStats
runtime.ReadMemStats(&m1)
// Simulate heavy usage
for i := 0; i < iterations; i++ {
// Add new token
tb.Add(fmt.Sprintf("token%d", i), time.Now().Add(time.Hour))
// Periodically check blacklisted status
if i%100 == 0 {
tb.IsBlacklisted(fmt.Sprintf("token%d", i-50))
}
// Periodically cleanup
if i%1000 == 0 {
tb.Cleanup()
}
}
// Force GC and wait for it to complete
runtime.GC()
time.Sleep(100 * time.Millisecond)
runtime.ReadMemStats(&m2)
// Check memory growth (using HeapAlloc for more accurate measurement)
memoryGrowth := int64(m2.HeapAlloc - m1.HeapAlloc)
maxAllowedGrowth := int64(2 * 1024 * 1024) // 2MB max growth
if memoryGrowth > maxAllowedGrowth {
t.Logf("Initial HeapAlloc: %d, Final HeapAlloc: %d", m1.HeapAlloc, m2.HeapAlloc)
t.Errorf("Excessive memory growth: %d bytes", memoryGrowth)
}
// Verify size stayed within limits
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size: %d", tb.Count())
}
}
func TestConcurrentTokenBlacklistOperations(t *testing.T) {
tb := NewTokenBlacklist()
iterations := 1000
concurrency := 10
done := make(chan bool)
// Start multiple goroutines performing operations
for i := 0; i < concurrency; i++ {
go func(id int) {
for j := 0; j < iterations; j++ {
// Add tokens
token := fmt.Sprintf("token%d-%d", id, j)
tb.Add(token, time.Now().Add(time.Hour))
// Check blacklist status
tb.IsBlacklisted(token)
// Periodic cleanup
if j%100 == 0 {
tb.Cleanup()
}
}
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < concurrency; i++ {
<-done
}
// Verify size constraints were maintained
if tb.Count() > 1000 {
t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count())
}
}
// Removed tests related to the old TokenBlacklist implementation:
// - TestTokenBlacklistSizeLimit
// - TestTokenBlacklistExpiredCleanup
// - TestTokenBlacklistOldestEviction
// - TestTokenBlacklistMemoryUsage
// - TestConcurrentTokenBlacklistOperations
func TestTokenCacheMemoryUsage(t *testing.T) {
tc := NewTokenCache()
+5 -2
View File
@@ -15,8 +15,10 @@ import (
"time"
)
var replayCacheMu sync.Mutex
var replayCache = make(map[string]time.Time)
var (
replayCacheMu sync.Mutex
replayCache = make(map[string]time.Time)
)
func cleanupReplayCache() {
now := time.Now()
@@ -164,6 +166,7 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
+259 -38
View File
@@ -9,11 +9,10 @@ import (
"net"
"net/http"
"net/url"
"runtime"
"strings"
"time"
"runtime"
"github.com/google/uuid"
"golang.org/x/time/rate"
)
@@ -52,7 +51,10 @@ func createDefaultHTTPClient() *http.Client {
}
}
const ConstSessionTimeout = 86400 // Session timeout in seconds
const (
ConstSessionTimeout = 86400 // Session timeout in seconds
defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI
)
// TokenVerifier interface for token verification
type TokenVerifier interface {
@@ -64,6 +66,13 @@ type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TokenExchanger defines methods for OIDC token operations
type TokenExchanger interface {
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error)
RevokeTokenWithProvider(token, tokenType string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
@@ -74,7 +83,7 @@ type TraefikOidc struct {
revocationURL string
jwkCache JWKCacheInterface
metadataCache *MetadataCache
tokenBlacklist *TokenBlacklist
tokenBlacklist *Cache // Replaced TokenBlacklist with generic Cache
jwksURL string
clientID string
clientSecret string
@@ -94,12 +103,13 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
postLogoutRedirectURI string
sessionManager *SessionManager
// exchangeCodeForTokenFunc func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) // Replaced by interface
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
postLogoutRedirectURI string
sessionManager *SessionManager
tokenExchanger TokenExchanger // Added field for mocking
}
// ProviderMetadata holds OIDC provider metadata
@@ -155,6 +165,29 @@ func (t *TraefikOidc) VerifyToken(token string) error {
// Cache the verified token
t.cacheVerifiedToken(token, jwt.Claims)
// Add JTI to blacklist AFTER successful verification to prevent replay
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
// Calculate expiry based on 'exp' claim if available, otherwise use default
expiry := time.Now().Add(defaultBlacklistDuration)
if expClaim, expOk := jwt.Claims["exp"].(float64); expOk {
expTime := time.Unix(int64(expClaim), 0)
tokenDuration := time.Until(expTime)
// Use token expiry if longer than default, capped at a reasonable max (e.g., 24h)
if tokenDuration > defaultBlacklistDuration && tokenDuration < (24*time.Hour) {
expiry = expTime
} else if tokenDuration <= 0 {
// If token already expired but somehow passed verification, use default
expiry = time.Now().Add(defaultBlacklistDuration)
} else {
// Use default if token expiry is shorter or excessively long
expiry = time.Now().Add(defaultBlacklistDuration)
}
}
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
t.logger.Debugf("Added JTI %s to blacklist cache", jti)
}
return nil
}
@@ -165,11 +198,22 @@ func (t *TraefikOidc) performPreVerificationChecks(token string) error {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
// Check if the raw token string itself is blacklisted (e.g., via explicit revocation)
if _, exists := t.tokenBlacklist.Get(token); exists {
return fmt.Errorf("token is blacklisted (raw string) in cache")
}
// Also check if the JTI claim is blacklisted (replay detection)
claims, err := extractClaims(token) // Use existing helper
if err == nil { // Only check JTI if claims could be extracted
if jti, ok := claims["jti"].(string); ok && jti != "" {
if _, exists := t.tokenBlacklist.Get(jti); exists {
// Use a specific error message for replay
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
}
}
} // If claims extraction fails, proceed; full validation will catch token issues later.
return nil
}
@@ -296,7 +340,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.PostLogoutRedirectURI
}(),
tokenBlacklist: NewTokenBlacklist(),
tokenBlacklist: NewCache(), // Use generic cache for blacklist
jwkCache: &JWKCache{},
metadataCache: NewMetadataCache(),
clientID: config.ClientID,
@@ -316,7 +360,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
// t.exchangeCodeForTokenFunc = t.exchangeCodeForToken // Removed, using interface now
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -329,6 +373,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.tokenVerifier = t
t.jwtVerifier = t
t.startTokenCleanup()
t.tokenExchanger = t // Initialize the interface field to self
go t.initializeMetadata(config.ProviderURL)
return t, nil
@@ -532,15 +577,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !authenticated {
// Original logic: Always initiate authentication if not authenticated
t.logger.Debug("User not authenticated, initiating OIDC flow")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
return // Stop processing
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
// Original logic: Always handle failed refresh as an expired token
t.logger.Debug("Token refresh failed, handling as expired token")
t.handleExpiredToken(rw, req, session, redirectURL)
return
return // Stop processing
}
}
@@ -621,6 +670,151 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.next.ServeHTTP(rw, req)
}
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear authentication data but preserve CSRF state
session.SetAuthenticated(false)
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
// Save the cleared session state
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save cleared session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate CSRF state
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
http.Error(rw, "No code in callback", http.StatusBadRequest)
return
}
// Get the code verifier from the session for PKCE flow
codeVerifier := session.GetCodeVerifier()
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify tokens and claims
// Use the exported VerifyToken method now that handleCallback is in main.go
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Validate user's email domain
// Use the unexported isAllowedDomain method now that handleCallback is in main.go
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
return
}
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
@@ -675,17 +869,26 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(accessToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
// Verify the token structure and signature first
jwt, err := parseJWT(accessToken)
if err != nil {
t.logger.Errorf("Failed to parse JWT during auth check: %v", err)
return false, false, true // Invalid format, treat as expired/invalid
}
if err := t.VerifyJWTSignatureAndClaims(jwt, accessToken); err != nil {
// Check if the error is specifically about expiration
if strings.Contains(err.Error(), "token has expired") {
t.logger.Debugf("Token signature/claims valid but token expired, attempting refresh")
// Token is expired but otherwise valid, signal for refresh
return true, true, false // Authenticated=true (was valid), NeedsRefresh=true, Expired=false (because refresh is possible)
}
// Other verification error (signature, issuer, audience etc.)
t.logger.Errorf("Token verification failed (non-expiration): %v", err)
return false, false, true // Token is invalid for other reasons
}
claims, err := extractClaims(accessToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true
}
// Claims already parsed within VerifyJWTSignatureAndClaims if it didn't error early
claims := jwt.Claims
expClaim, ok := claims["exp"].(float64)
if !ok {
@@ -696,17 +899,18 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
now := time.Now().Unix()
expTime := int64(expClaim)
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true
}
gracePeriod := time.Minute * 5
if now+int64(gracePeriod.Seconds()) > expTime {
t.logger.Debug("Token will expire soon")
return true, true, false // Token will expire soon, needs refresh
// Expiration check is now handled within VerifyJWTSignatureAndClaims logic above
// We only get here if the token is valid and not expired
// Check if token is nearing expiration (needs refresh proactively)
// Define a grace period, e.g., 5 minutes before actual expiry
refreshGracePeriod := int64(5 * 60)
if expTime-now < refreshGracePeriod {
t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod)
return true, true, false // Needs proactive refresh
}
// Token is valid, not expired, and not nearing expiration
return true, false, false
}
@@ -827,7 +1031,7 @@ func (t *TraefikOidc) startTokenCleanup() {
for range ticker.C {
t.logger.Debug("Starting token cleanup cycle")
t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
// t.tokenBlacklist.Cleanup() // Removed: Generic Cache handles its own cleanup
t.jwkCache.Cleanup() // Assuming jwkCache is the cache from cache.go
// Removed runtime.GC() call
}
@@ -841,7 +1045,8 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
t.tokenBlacklist.Set(token, true, time.Until(expiry))
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -890,7 +1095,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
newToken, err := t.getNewTokenWithRefreshToken(refreshToken)
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
if err != nil {
t.logger.Errorf("Failed to refresh token: %v", err)
return false
@@ -986,3 +1191,19 @@ func buildFullURL(scheme, host, path string) string {
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
// --- TokenExchanger Interface Implementation ---
// ExchangeCodeForToken implements the TokenExchanger interface.
// It calls the existing exchangeTokens helper function.
func (t *TraefikOidc) ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Note: The original exchangeTokens helper is defined in helpers.go and is already a method on *TraefikOidc
return t.exchangeTokens(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
// GetNewTokenWithRefreshToken implements the TokenExchanger interface.
// It calls the existing getNewTokenWithRefreshToken helper function.
func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
// Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc
return t.getNewTokenWithRefreshToken(refreshToken)
}
+303 -66
View File
@@ -101,29 +101,47 @@ func (ts *TestSuite) Setup() {
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenBlacklist: NewCache(), // Use generic cache for blacklist
tokenCache: NewTokenCache(),
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
// Explicitly set paths as New() is bypassed
redirURLPath: "/callback", // Assume default callback path for tests
logoutURLPath: "/callback/logout", // Assume default logout path for tests
tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
// ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc // Removed
ts.tOidc.tokenVerifier = ts.tOidc
ts.tOidc.jwtVerifier = ts.tOidc
// Set default mock exchanger
ts.tOidc.tokenExchanger = &MockTokenExchanger{
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
// Default mock behavior for code exchange
return &TokenResponse{
IDToken: ts.token, // Use the valid token from setup
AccessToken: ts.token,
RefreshToken: "default-refresh-token",
ExpiresIn: 3600,
}, nil
},
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
// Default mock behavior for refresh (can be overridden in tests)
return nil, fmt.Errorf("default mock: refresh not expected")
},
RevokeTokenFunc: func(token, tokenType string) error {
// Default mock behavior for revoke
return nil
},
}
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
}
// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface.
// MockJWKCache implements JWKCacheInterface
type MockJWKCache struct {
@@ -141,6 +159,34 @@ func (m *MockJWKCache) Cleanup() {
m.Err = nil
}
// MockTokenExchanger implements TokenExchanger for testing
type MockTokenExchanger struct {
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
}
func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
if m.ExchangeCodeFunc != nil {
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
}
return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock")
}
func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
if m.RefreshTokenFunc != nil {
return m.RefreshTokenFunc(refreshToken)
}
return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock")
}
func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
if m.RevokeTokenFunc != nil {
return m.RevokeTokenFunc(token, tokenType)
}
return fmt.Errorf("RevokeTokenFunc not implemented in mock")
}
// Helper function to create a JWT token
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
header := map[string]interface{}{
@@ -228,13 +274,14 @@ func TestVerifyToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache for each test
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
ts.tOidc.tokenBlacklist.Add(tc.token, time.Now().Add(1*time.Hour))
// Use Set with a duration. Value 'true' is arbitrary.
ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour)
}
if tc.rateLimit {
@@ -282,13 +329,53 @@ func TestServeHTTP(t *testing.T) {
ts.tOidc.next = nextHandler
ts.tOidc.name = "test"
// Helper to create an expired token
createExpiredToken := func() string {
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
iat := time.Now().Add(-2 * time.Hour).Unix()
nbf := time.Now().Add(-2 * time.Hour).Unix()
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce-expired", // Different nonce for clarity
"jti": generateRandomString(16),
})
return expiredToken
}
// Helper to create a new valid token (simulating refresh)
createNewValidToken := func() string {
exp := time.Now().Add(1 * time.Hour).Unix() // Valid for 1 hour
iat := time.Now().Unix()
nbf := time.Now().Unix()
newToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject",
"email": "user@example.com",
// "nonce": "test-nonce-new", // Nonce is typically not included/validated in refreshed tokens
"jti": generateRandomString(16),
})
return newToken
}
tests := []struct {
name string
requestPath string
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
name string
requestPath string
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) // Added for post-request checks
}{
{
name: "Excluded URL",
@@ -299,28 +386,77 @@ func TestServeHTTP(t *testing.T) {
{
name: "Unauthenticated request to protected URL",
requestPath: "/protected",
expectedStatus: http.StatusFound,
expectedStatus: http.StatusFound, // Expect redirect to OIDC
},
{
name: "Authenticated request to protected URL",
name: "Authenticated request to protected URL (Valid Token)",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
session.SetAccessToken(ts.token) // Use the valid token generated in Setup
session.SetRefreshToken("valid-refresh-token")
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
},
{
name: "Authenticated request with expired token and successful refresh",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Still marked authenticated initially
session.SetEmail("user@example.com")
session.SetAccessToken(createExpiredToken()) // Set expired token
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
},
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
return func(refreshToken string) (*TokenResponse, error) {
if refreshToken != "valid-refresh-token" {
return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken)
}
// Simulate successful refresh
newToken := createNewValidToken()
return &TokenResponse{
IDToken: newToken, // Return new valid token
AccessToken: newToken, // Often the same as ID token in tests
RefreshToken: "new-refresh-token",
ExpiresIn: 3600,
}, nil
}
},
expectedStatus: http.StatusOK, // Expect success after refresh
expectedBody: "OK",
assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) {
// Create a new request to read the cookies set by the response recorder
reqForCookieRead := httptest.NewRequest("GET", "/protected", nil)
for _, cookie := range rr.Result().Cookies() {
reqForCookieRead.AddCookie(cookie)
}
// Get session based on response cookies
session, err := sessionManager.GetSession(reqForCookieRead)
if err != nil {
t.Fatalf("Failed to get session after request: %v", err)
}
// Assert new tokens are in the session
// Direct comparison with createNewValidToken() is flawed as it generates a new token each time.
// Instead, check if the token was updated (not empty) and verify the refresh token.
if session.GetAccessToken() == "" {
t.Errorf("Expected access token to be updated in session, but it was empty")
}
if session.GetRefreshToken() != "new-refresh-token" {
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
}
},
},
{
name: "Logout URL",
requestPath: "/logout",
requestPath: "/logout", // Assuming logout path is configured or defaulted correctly
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
},
expectedStatus: http.StatusOK,
expectedStatus: http.StatusFound, // Expect redirect after logout
expectedBody: "",
},
}
@@ -328,40 +464,79 @@ func TestServeHTTP(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tc.requestPath, nil)
req.Header.Set("X-Forwarded-Proto", "http")
req.Header.Set("X-Forwarded-Host", "localhost")
// Set common headers needed by the logic (determineScheme, determineHost)
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
req.Header.Set("X-Forwarded-Host", "testhost.com")
req.Host = "testhost.com" // Also set Host header
rr := httptest.NewRecorder()
// Setup session if needed
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err)
}
if tc.setupSession != nil {
tc.setupSession(session)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
// Save session to recorder to get cookies
saveRecorder := httptest.NewRecorder()
if err := session.Save(req, saveRecorder); err != nil {
t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
// Copy cookies from save recorder to the actual request
for _, cookie := range saveRecorder.Result().Cookies() {
req.AddCookie(cookie)
}
rr = httptest.NewRecorder()
}
// Mocking setup for TokenExchanger
originalExchanger := ts.tOidc.tokenExchanger // Store original
mockExchanger, isMock := originalExchanger.(*MockTokenExchanger)
if !isMock {
// This case should ideally not happen if Setup correctly assigns the mock,
// but handle it defensively.
t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.")
mockExchanger = &MockTokenExchanger{
ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken,
RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken,
RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider,
}
ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock
}
// Override specific mock methods if needed for the test case
originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func
if tc.mockRefreshTokenFunc != nil {
// Assign the test case specific mock function
mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken)
}
// Call ServeHTTP
ts.tOidc.ServeHTTP(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
// Restore original exchanger and mock function state
ts.tOidc.tokenExchanger = originalExchanger
if tc.mockRefreshTokenFunc != nil && mockExchanger != nil {
// Restore the previous mock function if we overrode it
mockExchanger.RefreshTokenFunc = originalMockRefreshFunc
}
// Check response status
if rr.Code != tc.expectedStatus {
t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String())
}
// Check response body if expected
if tc.expectedBody != "" {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
}
}
// Perform post-request session assertions if defined
if tc.assertSessionAfterRequest != nil {
tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager)
}
})
}
}
@@ -552,17 +727,39 @@ func TestHandleCallback(t *testing.T) {
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
// Generate a unique token for this test case to avoid replay issues
// Use claims relevant to this test (disallowed email)
now := time.Now()
exp := now.Add(1 * time.Hour).Unix()
iat := now.Unix()
nbf := now.Unix()
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": exp,
"iat": iat,
"nbf": nbf,
"sub": "test-subject-disallowed",
"email": "user@disallowed.com", // The disallowed email for this test
"nonce": "test-nonce", // Match the nonce set in sessionSetupFunc
"jti": generateRandomString(16), // Unique JTI
})
if err != nil {
return nil, fmt.Errorf("failed to create disallowed token for test: %w", err)
}
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@disallowed.com",
"nonce": "test-nonce",
IDToken: disallowedToken,
RefreshToken: "test-refresh-token-disallowed",
}, nil
},
// Remove mock extractClaimsFunc - let the real one parse the disallowedToken
// The test should still fail correctly on the email check later.
// extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// return map[string]interface{}{
// "email": "user@disallowed.com",
// "nonce": "test-nonce",
// }, nil
// },
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
@@ -635,20 +832,61 @@ func TestHandleCallback(t *testing.T) {
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Clear the global replay cache before each test run
replayCacheMu.Lock()
replayCache = make(map[string]time.Time) // Reset the global cache
replayCacheMu.Unlock()
// Explicitly clear the shared blacklist at the start of each sub-test
// to ensure no state leaks, even though we expect the local one to be used.
// Note: This line might be redundant now that the verifier is local, but keep for safety.
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
logger := NewLogger("info")
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
sessionManager: sessionManager,
instanceExtractClaimsFunc := tc.extractClaimsFunc
if instanceExtractClaimsFunc == nil {
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
}
tOidc := &TraefikOidc{
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
tokenVerifier: nil, // Will be set to self below
jwtVerifier: nil, // Temporarily nil, will be set below
sessionManager: sessionManager,
tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
// Wrap the test case function to match the required signature
if tc.exchangeCodeForToken != nil {
// Only call if the test case provided a function
return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier)
}
// Provide a default behavior or error if no mock was provided for this test case
return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case")
},
// Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback
},
tokenCache: NewTokenCache(), // Initialize token cache
limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter
tokenBlacklist: NewCache(), // Initialize token blacklist cache
// Add potentially missing fields based on New() comparison
clientID: ts.tOidc.clientID,
issuerURL: ts.tOidc.issuerURL,
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
httpClient: ts.tOidc.httpClient,
initComplete: make(chan struct{}), // Initialize the channel
// Setting other fields like paths, enablePKCE etc. if needed
}
tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW
tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW
close(tOidc.initComplete) // Mark this test instance as initialized
// Create request and response recorder
req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil)
@@ -839,13 +1077,14 @@ func TestOIDCHandler(t *testing.T) {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour))
// Use Set with a duration. Value 'true' is arbitrary.
ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour)
}
if tc.rateLimit {
@@ -948,7 +1187,7 @@ func TestHandleLogout(t *testing.T) {
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: logger,
tokenBlacklist: NewTokenBlacklist(),
tokenBlacklist: NewCache(), // Use generic cache for blacklist
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
@@ -1018,13 +1257,13 @@ func TestHandleLogout(t *testing.T) {
// Check token blacklist
if token := session.GetAccessToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Access token was not blacklisted")
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Access token was not blacklisted in cache")
}
}
if token := session.GetRefreshToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Refresh token was not blacklisted")
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Refresh token was not blacklisted in cache")
}
}
})
@@ -1121,7 +1360,7 @@ func TestRevokeToken(t *testing.T) {
t.Run("Token revocation", func(t *testing.T) {
// Create a new instance for this specific test
tOidc := &TraefikOidc{
tokenBlacklist: NewTokenBlacklist(),
tokenBlacklist: NewCache(), // Use generic cache for blacklist
tokenCache: NewTokenCache(),
}
@@ -1136,8 +1375,8 @@ func TestRevokeToken(t *testing.T) {
t.Error("Token was not removed from cache")
}
// Verify token was added to blacklist
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
// Verify token was added to blacklist cache
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
t.Error("Token was not added to blacklist")
}
})
@@ -1404,7 +1643,6 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
@@ -2043,7 +2281,6 @@ func TestExchangeCodeForToken(t *testing.T) {
// Test exchangeCodeForToken
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
+4 -9
View File
@@ -34,6 +34,7 @@ func (c *MetadataCache) Cleanup() {
c.metadata = nil
}
}
func (c *MetadataCache) isCacheValid() bool {
return c.metadata != nil && time.Now().Before(c.expiresAt)
}
@@ -67,15 +68,9 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
}
c.metadata = metadata
// Calculate expiration time based on usage patterns
usageCount := 0 // This should be replaced with actual usage tracking logic
if usageCount < 10 {
c.expiresAt = time.Now().Add(30 * time.Minute)
} else if usageCount < 50 {
c.expiresAt = time.Now().Add(1 * time.Hour)
} else {
c.expiresAt = time.Now().Add(2 * time.Hour)
}
// Set a fixed cache lifetime (e.g., 1 hour)
// TODO: Consider making this configurable or respecting HTTP cache headers
c.expiresAt = time.Now().Add(1 * time.Hour)
// End of GetMetadata
return metadata, nil
+9 -2
View File
@@ -1,7 +1,9 @@
package traefikoidc
import (
"math/rand"
"crypto/rand"
"fmt"
"math/big"
"net/http/httptest"
"strings"
"testing"
@@ -12,7 +14,12 @@ func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
// Handle error appropriately in a real application, maybe panic in test helper
panic(fmt.Sprintf("crypto/rand failed: %v", err))
}
b[i] = charset[num.Int64()]
}
return string(b)
}