package traefikoidc import ( "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/base64" "encoding/json" "fmt" "math/big" "net/http" "net/http/httptest" "net/url" "strings" "sync" "testing" "time" "github.com/gorilla/sessions" "golang.org/x/time/rate" ) // TestSuite holds common test data and setup type TestSuite struct { t *testing.T rsaPrivateKey *rsa.PrivateKey rsaPublicKey *rsa.PublicKey ecPrivateKey *ecdsa.PrivateKey tOidc *TraefikOidc mockJWKCache *MockJWKCache sessionManager *SessionManager // utf *UnifiedTestFramework // Removed - consolidated test framework token string } // NewTestSuite creates a new test suite with automatic cleanup func NewTestSuite(t *testing.T) *TestSuite { ts := &TestSuite{ t: t, // utf: NewUnifiedTestFramework(t), // Removed } return ts } // Setup initializes the test suite func (ts *TestSuite) Setup() { // Initialize unified test framework if not already done // Unified test framework removed - using direct cleanup var err error ts.rsaPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { ts.t.Fatalf("Failed to generate RSA key: %v", err) } ts.rsaPublicKey = &ts.rsaPrivateKey.PublicKey // Generate EC key for EC key tests ts.ecPrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { ts.t.Fatalf("Failed to generate EC key: %v", err) } // Create a JWK for the RSA public key jwk := JWK{ Kty: "RSA", Kid: "test-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))), } jwks := &JWKSet{ Keys: []JWK{jwk}, } // Create a mock JWKCache ts.mockJWKCache = &MockJWKCache{ JWKS: jwks, Err: nil, } // Create a test JWT token signed with the RSA private key // Create timestamps with proper clock skew now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": generateRandomString(16), }) if err != nil { ts.t.Fatalf("Failed to create test JWT: %v", err) } logger := NewLogger("info") ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger) // Create WaitGroup for the OIDC instance goroutineWG := &sync.WaitGroup{} // Initialize caches properly tokenBlacklist := NewCache() tokenCacheInternal := NewCache() tokenCache := &TokenCache{} if tokenCache.cache == nil { // Type assert to get the underlying UniversalCache if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok { tokenCache.cache = wrapper.cache } } // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ issuerURL: "https://test-issuer.com", clientID: "test-client-id", audience: "test-client-id", clientSecret: "test-client-secret", roleClaimName: "roles", // Set default for backward compatibility groupClaimName: "groups", // Set default for backward compatibility userIdentifierClaim: "email", // Set default for backward compatibility jwkCache: ts.mockJWKCache, jwksURL: "https://test-jwks-url.com", revocationURL: "https://revocation-endpoint.com", limiter: rate.NewLimiter(rate.Every(time.Second), 10), tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, logger: logger, allowedUserDomains: map[string]struct{}{"example.com": {}}, excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}}, httpClient: &http.Client{Timeout: 10 * time.Second}, // Explicitly set paths as New() is bypassed redirURLPath: "/callback", // Assume default callback path for tests logoutURLPath: "/callback/logout", // Assume default logout path for tests tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests extractClaimsFunc: extractClaims, initComplete: make(chan struct{}), sessionManager: ts.sessionManager, goroutineWG: goroutineWG, ctx: context.Background(), tokenCleanupStopChan: make(chan struct{}), metadataRefreshStopChan: make(chan struct{}), } close(ts.tOidc.initComplete) ts.tOidc.tokenVerifier = ts.tOidc ts.tOidc.jwtVerifier = ts.tOidc // Set default mock exchanger ts.tOidc.tokenExchanger = &MockTokenExchanger{ ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { // Default mock behavior for code exchange return &TokenResponse{ IDToken: ts.token, // Use the valid token from setup AccessToken: ts.token, RefreshToken: "default-refresh-token", ExpiresIn: 3600, }, nil }, RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) { // Default mock behavior for refresh (can be overridden in tests) return nil, fmt.Errorf("default mock: refresh not expected") }, RevokeTokenFunc: func(token, tokenType string) error { // Default mock behavior for revoke return nil }, } // OIDC instance created // Register cleanup ts.t.Cleanup(func() { if ts.tOidc.tokenBlacklist != nil { ts.tOidc.tokenBlacklist.Close() } if ts.tOidc.tokenCache != nil && ts.tOidc.tokenCache.cache != nil { ts.tOidc.tokenCache.cache.Close() } }) } // Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface. // MockJWKCache implements JWKCacheInterface type MockJWKCache struct { Err error JWKS *JWKSet mu sync.RWMutex } // Close is a no-op for the mock. func (m *MockJWKCache) Close() { // No operation needed for the mock. } func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) { m.mu.RLock() defer m.mu.RUnlock() return m.JWKS, m.Err } func (m *MockJWKCache) Cleanup() { // Mock cleanup is a no-op - we don't want to destroy the mock JWKS data // Real cleanup is for expired entries, not resetting all data } // MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls. type MockTokenVerifier struct { VerifyFunc func(token string) error } func (m *MockTokenVerifier) VerifyToken(token string) error { if m.VerifyFunc != nil { return m.VerifyFunc(token) } return fmt.Errorf("VerifyFunc not implemented in mock") } // MockTokenExchanger implements TokenExchanger for testing type MockTokenExchanger struct { ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) RefreshTokenFunc func(refreshToken string) (*TokenResponse, error) RevokeTokenFunc func(token, tokenType string) error } func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { if m.ExchangeCodeFunc != nil { return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier) } return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock") } func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) { if m.RefreshTokenFunc != nil { return m.RefreshTokenFunc(refreshToken) } return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock") } func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error { if m.RevokeTokenFunc != nil { return m.RevokeTokenFunc(token, tokenType) } return fmt.Errorf("RevokeTokenFunc not implemented in mock") } // Helper function to check if a token is a test token func isTestToken(token string) bool { // Parse the token without verification to check if it's a test token claims, err := extractClaims(token) if err != nil { return false } // Check if the issuer is our test issuer if iss, ok := claims["iss"].(string); ok { return iss == "https://test-issuer.com" } // Check if audience is our test client if aud, ok := claims["aud"].(string); ok { return aud == "test-client-id" } return false } // Helper function to create a new valid token for refresh tests using test suite func (ts *TestSuite) createNewValidToken() string { now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Add(-2 * time.Minute).Unix() nbf := now.Add(-2 * time.Minute).Unix() token, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": generateRandomString(16), }) return token } // Helper function to create a JWT token func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) { header := map[string]interface{}{ "alg": alg, "kid": kid, "typ": "JWT", } headerJSON, err := json.Marshal(header) if err != nil { return "", err } headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON) claimsJSON, err := json.Marshal(claims) if err != nil { return "", err } claimsEncoded := base64.RawURLEncoding.EncodeToString(claimsJSON) signedContent := headerEncoded + "." + claimsEncoded // Select the appropriate hash function based on algorithm var hashFunc crypto.Hash switch alg { case "RS256", "PS256": hashFunc = crypto.SHA256 case "RS384", "PS384": hashFunc = crypto.SHA384 case "RS512", "PS512": hashFunc = crypto.SHA512 default: return "", fmt.Errorf("unsupported algorithm: %s", alg) } hasher := hashFunc.New() hasher.Write([]byte(signedContent)) hashed := hasher.Sum(nil) var signatureBytes []byte // Use appropriate signing method based on algorithm if strings.HasPrefix(alg, "RS") { // PKCS1v15 signing for RS* algorithms signatureBytes, err = rsa.SignPKCS1v15(rand.Reader, privateKey, hashFunc, hashed) } else if strings.HasPrefix(alg, "PS") { // PSS signing for PS* algorithms signatureBytes, err = rsa.SignPSS(rand.Reader, privateKey, hashFunc, hashed, nil) } else { return "", fmt.Errorf("unsupported RSA algorithm: %s", alg) } if err != nil { return "", err } signatureEncoded := base64.RawURLEncoding.EncodeToString(signatureBytes) token := signedContent + "." + signatureEncoded return token, nil } func bigIntToBytes(i *big.Int) []byte { return i.Bytes() } // TestVerifyToken tests the VerifyToken method func TestVerifyToken(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string token string blacklist bool rateLimit bool cacheToken bool expectedError bool }{ { name: "Valid token", token: ts.token, expectedError: false, }, { name: "Invalid token signature", token: ts.token + "invalid", expectedError: true, }, { name: "Blacklisted token", token: ts.token, blacklist: true, expectedError: true, }, { name: "Rate limit exceeded", token: ts.token, rateLimit: true, expectedError: true, }, { name: "Token in cache", token: ts.token, cacheToken: true, expectedError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache for each test ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist // Clear the token cache instead of creating a new one (it's a singleton) ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.tokenCache.Clear() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { // Use Set with a duration. Value 'true' is arbitrary. ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour) } if tc.rateLimit { // Exceed rate limit ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0) } if tc.cacheToken { // Use more realistic claims for cached token ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{ "iss": "https://test-issuer.com", "sub": "test-subject", "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "jti": generateRandomString(16), // Add a JTI claim to prevent replay detection }, time.Minute) // Verify the token is actually in the cache if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists { t.Logf("Token found in cache with claims: %v", claims) } else { t.Logf("Token NOT found in cache despite cacheToken=true") } } err := ts.tOidc.VerifyToken(tc.token) if tc.expectedError && err == nil { t.Errorf("Test %s: expected error but got nil", tc.name) } if !tc.expectedError && err != nil { t.Errorf("Test %s: expected no error but got %v", tc.name, err) } }) } } // TestServeHTTP tests the ServeHTTP method func TestServeHTTP(t *testing.T) { ts := NewTestSuite(t) ts.Setup() nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) }) ts.tOidc.next = nextHandler ts.tOidc.name = "test" // Helper to create an expired token createExpiredToken := func() string { exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago iat := time.Now().Add(-2 * time.Hour).Unix() nbf := time.Now().Add(-2 * time.Hour).Unix() expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce-expired", // Different nonce for clarity "jti": generateRandomString(16), }) return expiredToken } tests := []struct { sessionValues map[interface{}]interface{} setupSession func(*SessionData) mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) requestHeaders map[string]string name string requestPath string expectedBody string expectedStatus int }{ { name: "Excluded URL", requestPath: "/favicon.ico", expectedStatus: http.StatusOK, expectedBody: "OK", }, { name: "Unauthenticated request (no refresh token) to protected URL", requestPath: "/protected", setupSession: func(session *SessionData) { // Ensure no tokens are set session.SetAuthenticated(false) session.SetAccessToken("") session.SetRefreshToken("") }, expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token }, { name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(false) // Not authenticated session.SetAccessToken("") // No access token session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { if refreshToken != "valid-refresh-token-for-unauth-test" { return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken) } // Simulate successful refresh newToken := ts.createNewValidToken() // Use helper from TestServeHTTP return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil } }, expectedStatus: http.StatusOK, // Expect OK after successful refresh expectedBody: "OK", }, { name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(false) // Not authenticated session.SetAccessToken("") // No access token session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { // Simulate failed refresh return nil, fmt.Errorf("mock error: refresh token invalid") } }, expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh }, { name: "Authenticated request to protected URL (Valid Token)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") // Generate a fresh valid token for this test case to avoid replay issues freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, expectedStatus: http.StatusOK, expectedBody: "OK", }, // This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist { name: "Authenticated request with expired token and successful refresh", requestPath: "/protected", setupSession: func(session *SessionData) { // NOTE: isUserAuthenticated now returns authenticated=false if access token is expired, // even if session.SetAuthenticated(true) was called. // We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt. session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token session.SetEmail("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(), "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) // Set expired token session.SetRefreshToken("valid-refresh-token") // Set valid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { if refreshToken != "valid-refresh-token" { return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken) } // Simulate successful refresh newToken := ts.createNewValidToken() return &TokenResponse{ IDToken: newToken, // Return new valid token AccessToken: newToken, // Often the same as ID token in tests RefreshToken: "new-refresh-token", ExpiresIn: 3600, }, nil } }, expectedStatus: http.StatusOK, // Expect success after refresh expectedBody: "OK", assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) { // Create a new request to read the cookies set by the response recorder reqForCookieRead := httptest.NewRequest("GET", "/protected", nil) for _, cookie := range rr.Result().Cookies() { reqForCookieRead.AddCookie(cookie) } // Get session based on response cookies session, err := sessionManager.GetSession(reqForCookieRead) if err != nil { t.Fatalf("Failed to get session after request: %v", err) } // Assert new tokens are in the session if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() { t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one") } if session.GetRefreshToken() != "new-refresh-token" { t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken()) } // Also check authenticated flag is now true if !session.GetAuthenticated() { t.Errorf("Expected session to be marked authenticated after successful refresh") } }, }, // This test case remains valid as the logic should still return 401 for API clients on refresh failure { name: "Logout URL", requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) }, expectedStatus: http.StatusFound, // Expect redirect after logout expectedBody: "", // No specific session assertion needed for logout redirect itself }, { name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Set flag initially session.SetEmail("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(), "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) // Expired access token session.SetRefreshToken("valid-refresh-token") // Valid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { // Simulate failed refresh return nil, fmt.Errorf("mock error: refresh token invalid or provider down") } }, requestHeaders: map[string]string{ "Accept": "application/json", }, expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt expectedBody: `{"error":"Unauthorized","error_description":"Token refresh failed","status_code":401}`, }, // This test case remains valid as the logic should still redirect browser clients on refresh failure { name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Set flag initially session.SetEmail("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(), "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) // Expired access token session.SetRefreshToken("valid-refresh-token") // Valid refresh token }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { // Simulate failed refresh return nil, fmt.Errorf("mock error: refresh token invalid or provider down") } }, requestHeaders: map[string]string{ "Accept": "text/html", // Browser client }, expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt }, // This test case remains valid as proactive refresh should still be attempted { name: "Authenticated request with token nearing expiry (needs refresh)", requestPath: "/protected", setupSession: func(session *SessionData) { // Create token expiring soon (e.g., 30s, within default 60s grace period) exp := time.Now().Add(30 * time.Second).Unix() iat := time.Now().Add(-1 * time.Minute).Unix() nbf := time.Now().Add(-1 * time.Minute).Unix() nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), }) session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(nearExpiryToken) session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { return func(refreshToken string) (*TokenResponse, error) { if refreshToken != "valid-refresh-token-for-near-expiry" { return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken) } // Simulate successful refresh newToken := ts.createNewValidToken() return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil } }, expectedStatus: http.StatusOK, // Expect success after proactive refresh expectedBody: "OK", }, // This test case remains valid as no refresh should be attempted { name: "Authenticated request with token valid (outside grace period)", requestPath: "/protected", setupSession: func(session *SessionData) { // Create token expiring later (e.g., 10 mins, outside default 60s grace period) exp := time.Now().Add(10 * time.Minute).Unix() iat := time.Now().Add(-1 * time.Minute).Unix() nbf := time.Now().Add(-1 * time.Minute).Unix() validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), }) session.SetAuthenticated(true) session.SetEmail("user@example.com") session.SetAccessToken(validToken) session.SetIDToken(validToken) // Ensure ID token is also set session.SetRefreshToken("should-not-be-used-refresh-token") }, mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) { // This should NOT be called return func(refreshToken string) (*TokenResponse, error) { t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period") return nil, fmt.Errorf("refresh should not have been attempted") } }, expectedStatus: http.StatusOK, // Expect success, no refresh needed expectedBody: "OK", }, { name: "Disallowed Domain (Accept: JSON)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@disallowed.com") // Use disallowed domain // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, requestHeaders: map[string]string{ "Accept": "application/json", }, expectedStatus: http.StatusForbidden, expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`, }, { name: "Disallowed Domain (Accept: HTML)", requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@disallowed.com") // Use disallowed domain // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), "iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email "jti": generateRandomString(16), // Unique JTI }) session.SetAccessToken(freshToken) session.SetIDToken(freshToken) // Ensure ID token is also set session.SetRefreshToken("valid-refresh-token") }, requestHeaders: map[string]string{ "Accept": "text/html", }, expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response expectedBody: "", // Body check is harder for HTML, focus on status and content-type }, } // Configure allowed domains for domain restriction tests // This allows example.com but not disallowed.com ts.tOidc.allowedUserDomains = map[string]struct{}{ "example.com": {}, } // Use mock JWK cache to enable proper token verification ts.tOidc.jwkCache = ts.mockJWKCache for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache for each test to prevent token replay detection errors ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() // Reset the global replayCache to prevent "token replay detected" errors cleanupReplayCache() initReplayCache() // Store original tokenVerifier to restore later origTokenVerifier := ts.tOidc.tokenVerifier // Create a mock tokenVerifier that clears the replay cache before verification // This prevents replay detection when the same token is verified multiple times within a test mockTokenVerifier := &MockTokenVerifier{ VerifyFunc: func(token string) error { // Clear replay cache before token verification cleanupReplayCache() initReplayCache() // For test tokens, perform basic validation without JWKS dependency if isTestToken(token) { // Parse the token to check basic validity and expiration claims, err := extractClaims(token) if err != nil { return fmt.Errorf("token parsing failed: %v", err) } // Check token expiration if exp, ok := claims["exp"].(float64); ok { if time.Now().Unix() > int64(exp) { return fmt.Errorf("token has expired") } } // Token is valid for test purposes - also cache the claims like the real verifier would if ts.tOidc.tokenCache != nil { ts.tOidc.tokenCache.Set(token, claims, time.Hour) } return nil } // For non-test tokens, call the original verifier if origTokenVerifier != nil { return origTokenVerifier.VerifyToken(token) } return fmt.Errorf("original token verifier is nil") }, } // Replace tokenVerifier with our mock ts.tOidc.tokenVerifier = mockTokenVerifier // Restore original tokenVerifier after test defer func() { ts.tOidc.tokenVerifier = origTokenVerifier }() req := httptest.NewRequest("GET", tc.requestPath, nil) // Set common headers needed by the logic (determineScheme, determineHost) req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that req.Header.Set("X-Forwarded-Host", "testhost.com") req.Host = "testhost.com" // Also set Host header // Set request headers from test case if tc.requestHeaders != nil { for key, value := range tc.requestHeaders { req.Header.Set(key, value) } } rr := httptest.NewRecorder() // Setup session if needed session, err := ts.tOidc.sessionManager.GetSession(req) if err != nil { t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err) } if tc.setupSession != nil { tc.setupSession(session) // Save session to recorder to get cookies saveRecorder := httptest.NewRecorder() if err := session.Save(req, saveRecorder); err != nil { t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err) } // Copy cookies from save recorder to the actual request for _, cookie := range saveRecorder.Result().Cookies() { req.AddCookie(cookie) } } // Mocking setup for TokenExchanger originalExchanger := ts.tOidc.tokenExchanger // Store original mockExchanger, isMock := originalExchanger.(*MockTokenExchanger) if !isMock { // This case should ideally not happen if Setup correctly assigns the mock, // but handle it defensively. t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.") mockExchanger = &MockTokenExchanger{ ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken, RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken, RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider, } ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock } // Override specific mock methods if needed for the test case originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func if tc.mockRefreshTokenFunc != nil { // Assign the test case specific mock function mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken) } // Call ServeHTTP ts.tOidc.ServeHTTP(rr, req) // Restore original exchanger and mock function state ts.tOidc.tokenExchanger = originalExchanger if tc.mockRefreshTokenFunc != nil && mockExchanger != nil { // Restore the previous mock function if we overrode it mockExchanger.RefreshTokenFunc = originalMockRefreshFunc } // Check response status if rr.Code != tc.expectedStatus { t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String()) } // Check response body if expected // Check response body if expected (handle JSON vs HTML) if tc.expectedBody != "" { // For JSON, compare directly if strings.Contains(rr.Header().Get("Content-Type"), "application/json") { if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body) } } else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody { t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body) } } // Add more sophisticated HTML body checks if needed } // Perform post-request session assertions if defined if tc.assertSessionAfterRequest != nil { tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager) } }) } } func TestJWKToPEM(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { jwk *JWK name string errorContains string expectError bool }{ { name: "Unsupported Key Type", jwk: &JWK{ Kty: "unsupported", Kid: "test-key-id", }, expectError: true, errorContains: "unsupported key type", }, { name: "EC Key", jwk: &JWK{ Kty: "EC", Kid: "test-ec-key-id", Crv: "P-256", X: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.X.Bytes()), Y: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.Y.Bytes()), }, expectError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { pemBytes, err := jwkToPEM(tc.jwk) if tc.expectError { if err == nil { t.Errorf("Expected error, got nil") } else if !strings.Contains(err.Error(), tc.errorContains) { t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err) } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } if len(pemBytes) == 0 { t.Error("PEM bytes should not be empty") } } }) } } func TestParseJWT(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string token string errorContains string expectError bool }{ { name: "Invalid Format", token: "invalid.jwt.token", expectError: true, errorContains: "invalid JWT format", }, { name: "Valid Token", token: ts.token, expectError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { _, err := parseJWT(tc.token) if tc.expectError { if err == nil { t.Errorf("Expected error, got nil") } else if !strings.Contains(err.Error(), tc.errorContains) { t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err) } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } } }) } } func TestJWTVerify_MissingClaims(t *testing.T) { ts := NewTestSuite(t) ts.Setup() jwt := &JWT{ Header: map[string]interface{}{ "alg": "RS256", "kid": "test-key-id", }, Claims: map[string]interface{}{ // Missing 'iss', 'aud', 'exp', 'iat', 'sub' }, } err := jwt.Verify("https://test-issuer.com", "test-client-id") if err == nil { t.Error("Expected error for missing claims, got nil") } } func TestHandleCallback(t *testing.T) { ts := NewTestSuite(t) ts.Setup() redirectURL := "http://example.com/" tests := []struct { exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) sessionSetupFunc func(*SessionData) name string queryParams string expectedStatus int }{ { name: "Success", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@example.com", "nonce": "test-nonce", }, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusFound, }, { name: "Missing Code", queryParams: "", sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusBadRequest, }, { name: "Exchange Code Error", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return nil, fmt.Errorf("exchange code error") }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, { name: "Missing ID Token", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{}, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, { name: "Disallowed Email", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Generate a unique token for this test case to avoid replay issues // Use claims relevant to this test (disallowed email) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject-disallowed", "email": "user@disallowed.com", // The disallowed email for this test "nonce": "test-nonce", // Match the nonce set in sessionSetupFunc "jti": generateRandomString(16), // Unique JTI }) if err != nil { return nil, fmt.Errorf("failed to create disallowed token for test: %w", err) } return &TokenResponse{ IDToken: disallowedToken, RefreshToken: "test-refresh-token-disallowed", }, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusForbidden, }, { name: "Invalid State Parameter", queryParams: "?code=test-code&state=invalid-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@example.com", "nonce": "test-nonce", }, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusBadRequest, }, { name: "Nonce Mismatch", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@example.com", "nonce": "invalid-nonce", }, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, { name: "Missing Nonce in Claims", queryParams: "?code=test-code&state=test-csrf-token", exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { return map[string]interface{}{ "email": "user@example.com", // Missing nonce }, nil }, sessionSetupFunc: func(session *SessionData) { session.SetCSRF("test-csrf-token") session.SetNonce("test-nonce") }, expectedStatus: http.StatusInternalServerError, }, } for _, tc := range tests { // Capture range variable t.Run(tc.name, func(t *testing.T) { // Clear the global replay cache before each test run cleanupReplayCache() initReplayCache() // Explicitly clear the shared blacklist at the start of each sub-test // to ensure no state leaks, even though we expect the local one to be used. // Note: This line might be redundant now that the verifier is local, but keep for safety. ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist logger := NewLogger("info") sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger) // Create a new instance for each test to avoid state carryover instanceExtractClaimsFunc := tc.extractClaimsFunc if instanceExtractClaimsFunc == nil { instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case } tOidc := &TraefikOidc{ allowedUserDomains: map[string]struct{}{"example.com": {}}, logger: logger, userIdentifierClaim: "email", // Required for claim extraction // exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function tokenVerifier: nil, // Will be set to self below jwtVerifier: nil, // Temporarily nil, will be set below sessionManager: sessionManager, tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) { // Wrap the test case function to match the required signature if tc.exchangeCodeForToken != nil { // Only call if the test case provided a function return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier) } // Provide a default behavior or error if no mock was provided for this test case return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case") }, // Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback }, tokenCache: NewTokenCache(), // Initialize token cache limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter tokenBlacklist: NewCache(), // Initialize token blacklist cache // Add potentially missing fields based on New() comparison clientID: ts.tOidc.clientID, audience: ts.tOidc.clientID, issuerURL: ts.tOidc.issuerURL, jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite httpClient: ts.tOidc.httpClient, initComplete: make(chan struct{}), // Initialize the channel // Setting other fields like paths, enablePKCE etc. if needed } tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW close(tOidc.initComplete) // Mark this test instance as initialized // Create request and response recorder req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil) rr := httptest.NewRecorder() // Create session session, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } if tc.sessionSetupFunc != nil { tc.sessionSetupFunc(session) } if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } // Copy cookies to the new request for _, cookie := range rr.Result().Cookies() { req.AddCookie(cookie) } // Reset response recorder for the actual test rr = httptest.NewRecorder() // Call handleCallback tOidc.handleCallback(rr, req, redirectURL) // Check response if rr.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) } }) } } func TestIsAllowedDomain(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { allowedDomains map[string]struct{} allowedUsers map[string]struct{} name string email string expectedLogOutput string allowed bool }{ { name: "Allowed domain", email: "user@example.com", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{}, allowed: true, }, { name: "Disallowed domain", email: "user@notallowed.com", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{}, allowed: false, }, { name: "Invalid email", email: "invalid-email", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{}, allowed: false, }, { name: "Specific user is allowed regardless of domain", email: "specific.user@otherdomain.com", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, allowed: true, }, { name: "Case-insensitive email matching for specific user", email: "Specific.User@otherdomain.com", // Mixed case allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, // Lowercase allowed: true, }, { name: "Only allowed users configured (no domains)", email: "specific.user@otherdomain.com", allowedDomains: map[string]struct{}{}, // Empty allowed domains allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, allowed: true, }, { name: "User not in allowed list when only specific users configured", email: "other.user@otherdomain.com", allowedDomains: map[string]struct{}{}, // Empty allowed domains allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, allowed: false, }, { name: "No restrictions (both empty)", email: "anyone@anydomain.com", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{}, allowed: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Configure TraefikOidc instance for this test case tOidc := ts.tOidc tOidc.allowedUserDomains = tc.allowedDomains tOidc.allowedUsers = tc.allowedUsers allowed := tOidc.isAllowedDomain(tc.email) if allowed != tc.allowed { t.Errorf("Expected allowed=%v, got %v", tc.allowed, allowed) } }) } } func TestIsAllowedUser(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { allowedDomains map[string]struct{} allowedUsers map[string]struct{} userIdentifierClaim string name string userIdentifier string allowed bool }{ // Email-based identification (default behavior) { name: "Email identifier - allowed domain", userIdentifier: "user@example.com", userIdentifierClaim: "email", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{}, allowed: true, }, { name: "Email identifier - disallowed domain", userIdentifier: "user@notallowed.com", userIdentifierClaim: "email", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{}, allowed: false, }, { name: "Email identifier - specific user allowed", userIdentifier: "specific.user@otherdomain.com", userIdentifierClaim: "email", allowedDomains: map[string]struct{}{"example.com": {}}, allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, allowed: true, }, // Non-email identifier (sub claim - for Azure AD users without email) { name: "Sub identifier - allowed in allowedUsers", userIdentifier: "abc12345-6789-0abc-def0-123456789abc", userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, allowed: true, }, { name: "Sub identifier - not in allowedUsers", userIdentifier: "xyz-not-allowed-user", userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, allowed: false, }, { name: "Sub identifier - allowedDomains ignored for non-email", userIdentifier: "user-id-12345", userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored allowedUsers: map[string]struct{}{"user-id-12345": {}}, allowed: true, }, { name: "Sub identifier - no restrictions allows all", userIdentifier: "any-user-id", userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{}, allowed: true, }, { name: "Sub identifier - case insensitive matching", userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase allowed: true, }, // OID claim (Azure AD object ID) { name: "OID identifier - allowed user", userIdentifier: "oid-12345-67890", userIdentifierClaim: "oid", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{"oid-12345-67890": {}}, allowed: true, }, // UPN claim (Azure AD User Principal Name) { name: "UPN identifier - allowed user (looks like email but use sub logic)", userIdentifier: "user@tenant.onmicrosoft.com", userIdentifierClaim: "upn", allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}}, allowed: true, }, // Edge cases { name: "Empty identifier - not allowed", userIdentifier: "", userIdentifierClaim: "sub", allowedDomains: map[string]struct{}{}, allowedUsers: map[string]struct{}{"some-user": {}}, allowed: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Configure TraefikOidc instance for this test case tOidc := ts.tOidc tOidc.allowedUserDomains = tc.allowedDomains tOidc.allowedUsers = tc.allowedUsers tOidc.userIdentifierClaim = tc.userIdentifierClaim allowed := tOidc.isAllowedUser(tc.userIdentifier) if allowed != tc.allowed { t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q", tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim) } }) } } func TestUserIdentifierClaimExtraction(t *testing.T) { // Test that the correct claim is extracted based on userIdentifierClaim config tests := []struct { name string userIdentifierClaim string claims map[string]interface{} expectedIdentifier string shouldFallbackToSub bool }{ { name: "Email claim extraction (default)", userIdentifierClaim: "email", claims: map[string]interface{}{ "sub": "user-sub-id", "email": "user@example.com", }, expectedIdentifier: "user@example.com", shouldFallbackToSub: false, }, { name: "Sub claim extraction", userIdentifierClaim: "sub", claims: map[string]interface{}{ "sub": "user-sub-id", "email": "user@example.com", }, expectedIdentifier: "user-sub-id", shouldFallbackToSub: false, }, { name: "OID claim extraction (Azure AD)", userIdentifierClaim: "oid", claims: map[string]interface{}{ "sub": "user-sub-id", "email": "user@example.com", "oid": "azure-object-id", }, expectedIdentifier: "azure-object-id", shouldFallbackToSub: false, }, { name: "UPN claim extraction (Azure AD)", userIdentifierClaim: "upn", claims: map[string]interface{}{ "sub": "user-sub-id", "upn": "user@tenant.onmicrosoft.com", }, expectedIdentifier: "user@tenant.onmicrosoft.com", shouldFallbackToSub: false, }, { name: "Fallback to sub when configured claim is missing", userIdentifierClaim: "email", claims: map[string]interface{}{ "sub": "fallback-sub-id", // email is missing }, expectedIdentifier: "fallback-sub-id", shouldFallbackToSub: true, }, { name: "preferred_username claim extraction", userIdentifierClaim: "preferred_username", claims: map[string]interface{}{ "sub": "user-sub-id", "preferred_username": "jdoe", }, expectedIdentifier: "jdoe", shouldFallbackToSub: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Extract user identifier using the same logic as auth_flow.go userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string) usedFallback := false if userIdentifier == "" && tc.userIdentifierClaim != "sub" { userIdentifier, _ = tc.claims["sub"].(string) usedFallback = true } if userIdentifier != tc.expectedIdentifier { t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier) } if usedFallback != tc.shouldFallbackToSub { t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback) } }) } } func TestOIDCHandler(t *testing.T) { ts := NewTestSuite(t) ts.Setup() ts.token = "valid.jwt.token" tests := []struct { exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) extractClaimsFunc func(tokenString string) (map[string]interface{}, error) sessionSetupFunc func(session *sessions.Session) name string queryParams string expectedStatus int blacklist bool rateLimit bool cacheToken bool }{ { name: "Missing Code", queryParams: "", sessionSetupFunc: func(session *sessions.Session) { // Set CSRF and nonce values in session session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { // Simulate extraction of claims with invalid nonce return map[string]interface{}{ "email": "user@example.com", "nonce": "invalid-nonce", }, nil }, expectedStatus: http.StatusInternalServerError, }, { name: "Missing Nonce in Claims", queryParams: "?code=test-code&state=test-csrf-token", sessionSetupFunc: func(session *sessions.Session) { // Set CSRF and nonce values in session session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { // Simulate extraction of claims without nonce return map[string]interface{}{ "email": "user@example.com", }, nil }, expectedStatus: http.StatusBadRequest, }, { name: "Invalid State Parameter", queryParams: "?code=test-code&state=invalid-csrf-token", sessionSetupFunc: func(session *sessions.Session) { // Set CSRF and nonce values in session session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { // Simulate extraction of claims return map[string]interface{}{ "email": "user@example.com", "nonce": "test-nonce", }, nil }, expectedStatus: http.StatusBadRequest, }, { name: "Nonce Mismatch", queryParams: "?code=test-code&state=test-csrf-token", sessionSetupFunc: func(session *sessions.Session) { // Set CSRF and nonce values in session session.Values["csrf"] = "test-csrf-token" session.Values["nonce"] = "test-nonce" }, exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) { // Simulate token exchange return &TokenResponse{ IDToken: ts.token, RefreshToken: "test-refresh-token", }, nil }, extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) { // Simulate extraction of claims with mismatched nonce return map[string]interface{}{ "email": "user@example.com", "nonce": "invalid-nonce", }, nil }, expectedStatus: http.StatusBadRequest, }, } for _, tc := range tests { // Capture range variable t.Run(tc.name, func(t *testing.T) { // Reset token blacklist and cache ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist ts.tOidc.tokenCache = NewTokenCache() ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10) // Set up the test case if tc.blacklist { // Use Set with a duration. Value 'true' is arbitrary. ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour) } if tc.rateLimit { // Exceed rate limit ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0) } if tc.cacheToken { // Cache the token with dummy claims ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{ "empty": "claim", }, 60) } }) } } // TestHandleLogout tests the logout functionality func TestHandleLogout(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Create mock revocation endpoint server mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { t.Errorf("Expected POST request, got %s", r.Method) } if err := r.ParseForm(); err != nil { t.Fatalf("Failed to parse form: %v", err) } // Verify the required parameters are present if r.Form.Get("token") == "" { t.Error("Missing token parameter") } if r.Form.Get("token_type_hint") == "" { t.Error("Missing token_type_hint parameter") } w.WriteHeader(http.StatusOK) })) defer mockRevocationServer.Close() tests := []struct { setupSession func(*SessionData) name string endSessionURL string expectedURL string host string expectedStatus int }{ { name: "Successful logout with end session endpoint", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetAccessToken(ValidAccessToken) session.SetIDToken(ValidIDToken) session.SetRefreshToken(ValidRefreshToken) }, endSessionURL: "https://provider/end-session", expectedStatus: http.StatusFound, expectedURL: "https://provider/end-session?id_token_hint=" + url.QueryEscape(ValidIDToken) + "&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", host: "test-host", }, { name: "Successful logout without end session endpoint", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetAccessToken(ValidAccessToken) session.SetIDToken(ValidIDToken) session.SetRefreshToken(ValidRefreshToken) }, endSessionURL: "", expectedStatus: http.StatusFound, expectedURL: "http://example.com/", host: "test-host", }, { name: "Logout with empty session", setupSession: func(session *SessionData) {}, expectedStatus: http.StatusFound, expectedURL: "http://example.com/", host: "test-host", }, { name: "Logout with invalid end session URL", setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetAccessToken(ValidAccessToken) session.SetIDToken(ValidIDToken) session.SetRefreshToken(ValidRefreshToken) }, endSessionURL: ":\\invalid-url", expectedStatus: http.StatusInternalServerError, host: "test-host", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger) tOidc := &TraefikOidc{ revocationURL: mockRevocationServer.URL, endSessionURL: tc.endSessionURL, logger: logger, tokenBlacklist: NewCache(), // Use generic cache for blacklist httpClient: &http.Client{}, clientID: "test-client-id", audience: "test-client-id", clientSecret: "test-client-secret", tokenCache: NewTokenCache(), forceHTTPS: false, sessionManager: sessionManager, } // Create request with proper headers req := httptest.NewRequest("GET", "/logout", nil) req.Header.Set("Host", tc.host) // Create a response recorder rr := httptest.NewRecorder() // Get a session session, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } if tc.setupSession != nil { tc.setupSession(session) } if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } // Copy cookies to the new request for _, cookie := range rr.Result().Cookies() { req.AddCookie(cookie) } // Reset response recorder rr = httptest.NewRecorder() // Handle logout tOidc.handleLogout(rr, req) // Check response if rr.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) } if tc.expectedURL != "" { location := rr.Header().Get("Location") if location != tc.expectedURL { t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location) } } // Verify session is cleared updatedSession, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get updated session: %v", err) } // Verify tokens are cleared if token := updatedSession.GetAccessToken(); token != "" { t.Error("Access token not cleared") } if token := updatedSession.GetRefreshToken(); token != "" { t.Error("Refresh token not cleared") } if updatedSession.GetAuthenticated() { t.Error("Session still marked as authenticated") } // Check token blacklist if token := session.GetAccessToken(); token != "" { if _, exists := tOidc.tokenBlacklist.Get(token); !exists { t.Error("Access token was not blacklisted in cache") } } if token := session.GetRefreshToken(); token != "" { if _, exists := tOidc.tokenBlacklist.Get(token); !exists { t.Error("Refresh token was not blacklisted in cache") } } }) } } // TestRevokeTokenWithProvider tests the token revocation with provider func TestRevokeTokenWithProvider(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string token string tokenType string statusCode int expectError bool }{ { name: "Successful token revocation", token: "valid-token", tokenType: "refresh_token", statusCode: http.StatusOK, expectError: false, }, { name: "Failed token revocation", token: "invalid-token", tokenType: "refresh_token", statusCode: http.StatusBadRequest, expectError: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Create test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Verify request method and content type if r.Method != "POST" { t.Errorf("Expected POST request, got %s", r.Method) } if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct) } // Verify form values if err := r.ParseForm(); err != nil { t.Fatalf("Failed to parse form: %v", err) } if got := r.Form.Get("token"); got != tc.token { t.Errorf("Expected token %s, got %s", tc.token, got) } if got := r.Form.Get("token_type_hint"); got != tc.tokenType { t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got) } if got := r.Form.Get("client_id"); got != ts.tOidc.clientID { t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got) } if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret { t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got) } w.WriteHeader(tc.statusCode) })) defer server.Close() // Set revocation URL to test server ts.tOidc.revocationURL = server.URL // Test token revocation err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType) if tc.expectError && err == nil { t.Error("Expected error but got nil") } if !tc.expectError && err != nil { t.Errorf("Unexpected error: %v", err) } }) } } // TestRevokeToken tests the token revocation functionality func TestRevokeToken(t *testing.T) { ts := NewTestSuite(t) ts.Setup() token := "test.token.with.claims" claims := map[string]interface{}{ "exp": float64(time.Now().Add(time.Hour).Unix()), } // Test token revocation t.Run("Token revocation", func(t *testing.T) { // Create a new instance for this specific test tOidc := &TraefikOidc{ tokenBlacklist: NewCache(), // Use generic cache for blacklist tokenCache: NewTokenCache(), logger: NewLogger("info"), // Initialize the logger } // Cache the token tOidc.tokenCache.Set(token, claims, time.Hour) // Revoke the token tOidc.RevokeToken(token) // Verify token was removed from cache if _, exists := tOidc.tokenCache.Get(token); exists { t.Error("Token was not removed from cache") } // Verify token was added to blacklist cache if _, exists := tOidc.tokenBlacklist.Get(token); !exists { t.Error("Token was not added to blacklist") } }) } // Add this new test function func TestBuildLogoutURL(t *testing.T) { tests := []struct { name string endSessionURL string idToken string postLogoutRedirect string expectedURL string expectError bool }{ { name: "Valid URL", endSessionURL: "https://provider/end-session", idToken: "test.id.token", postLogoutRedirect: "http://example.com/", expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", expectError: false, }, { name: "Invalid URL", endSessionURL: "://invalid-url", idToken: "test.id.token", postLogoutRedirect: "http://example.com/", expectError: true, }, { name: "URL with existing query parameters", endSessionURL: "https://provider/end-session?existing=param", idToken: "test.id.token", postLogoutRedirect: "http://example.com/", expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F", expectError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect) if tc.expectError { if err == nil { t.Error("Expected error but got nil") } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } if url != tc.expectedURL { t.Errorf("Expected URL %q, got %q", tc.expectedURL, url) } } }) } } // Add this new test function func TestHandleExpiredToken(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string setupSession func(*SessionData) expectedPath string }{ { name: "Basic expired token", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(), "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) session.SetEmail("test@example.com") }, expectedPath: "/original/path", }, { name: "Session with additional values", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), "iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(), "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) session.mainSession.Values["custom_value"] = "should-be-cleared" }, expectedPath: "/another/path", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { logger := NewLogger("info") sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger) tOidc := &TraefikOidc{ sessionManager: sessionManager, logger: logger, tokenVerifier: ts.tOidc.tokenVerifier, jwtVerifier: ts.tOidc.jwtVerifier, initComplete: make(chan struct{}), initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { http.Redirect(rw, req, "/login", http.StatusFound) }, } close(tOidc.initComplete) // Create request req := httptest.NewRequest("GET", tc.expectedPath, nil) rr := httptest.NewRecorder() // Get session session, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } // Setup session data tc.setupSession(session) // Handle expired token tOidc.handleExpiredToken(rr, req, session, tc.expectedPath) // Get the updated session to verify changes updatedSession, err := sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get updated session: %v", err) } // Verify main session values if updatedSession.GetCSRF() == "" { t.Error("CSRF token not set") } if path := updatedSession.GetIncomingPath(); path != tc.expectedPath { t.Errorf("Expected path %s, got %s", tc.expectedPath, path) } if updatedSession.GetNonce() == "" { t.Error("Nonce not set") } // Verify tokens are cleared if token := updatedSession.GetAccessToken(); token != "" { t.Error("Access token not cleared") } if token := updatedSession.GetRefreshToken(); token != "" { t.Error("Refresh token not cleared") } // Verify redirect status if rr.Code != http.StatusFound { t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code) } }) } } // Add this new test function func TestExtractGroupsAndRoles(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string claims map[string]interface{} expectGroups []string expectRoles []string expectError bool }{ { name: "Valid groups and roles", claims: map[string]interface{}{ "groups": []interface{}{"group1", "group2"}, "roles": []interface{}{"role1", "role2"}, }, expectGroups: []string{"group1", "group2"}, expectRoles: []string{"role1", "role2"}, expectError: false, }, { name: "Empty groups and roles", claims: map[string]interface{}{ "groups": []interface{}{}, "roles": []interface{}{}, }, expectGroups: []string{}, expectRoles: []string{}, expectError: false, }, { name: "Invalid groups format", claims: map[string]interface{}{ "groups": "not-an-array", "roles": []interface{}{"role1"}, }, expectError: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Create a test token with the claims token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) if err != nil { t.Fatalf("Failed to create test token: %v", err) } groups, roles, err := ts.tOidc.extractGroupsAndRoles(token) if tc.expectError { if err == nil { t.Error("Expected error but got nil") } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } // Compare groups if !stringSliceEqual(groups, tc.expectGroups) { t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups) } // Compare roles if !stringSliceEqual(roles, tc.expectRoles) { t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles) } } }) } } // TestMultipleMiddlewareInstances verifies that multiple middleware instances // can be created and initialized properly for different routes func TestMultipleMiddlewareInstances(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode") } // Create mock provider metadata server mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } metadata := ProviderMetadata{ Issuer: "https://test-issuer.com", AuthURL: "https://test-issuer.com/auth", TokenURL: "https://test-issuer.com/token", JWKSURL: "https://test-issuer.com/jwks", RevokeURL: "https://test-issuer.com/revoke", EndSessionURL: "https://test-issuer.com/end-session", } json.NewEncoder(w).Encode(metadata) })) defer mockServer.Close() // Create base config config := &Config{ ProviderURL: mockServer.URL, ClientID: "test-client", ClientSecret: "test-secret", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", } // Create multiple middleware instances routes := []string{"/api/v1", "/api/v2", "/api/v3"} var middlewares []*TraefikOidc for _, route := range routes { config.CallbackURL = route + "/callback" middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test") if err != nil { t.Fatalf("Failed to create middleware for route %s: %v", route, err) } // Type assert to access internal fields if m, ok := middleware.(*TraefikOidc); ok { middlewares = append(middlewares, m) } else { t.Fatalf("Middleware is not of type *TraefikOidc") } } // Clean up all middleware instances to prevent goroutine leaks defer func() { for i, m := range middlewares { if err := m.Close(); err != nil { t.Errorf("Failed to close middleware instance %d: %v", i, err) } } }() // Wait for all instances to initialize for i, m := range middlewares { select { case <-m.initComplete: case <-time.After(5 * time.Second): t.Fatalf("Middleware instance %d failed to initialize", i) } // Verify each instance has its own unique configuration if m.issuerURL != "https://test-issuer.com" { t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL) } if m.authURL != "https://test-issuer.com/auth" { t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL) } if m.tokenURL != "https://test-issuer.com/token" { t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL) } if m.jwksURL != "https://test-issuer.com/jwks" { t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL) } if m.redirURLPath != routes[i]+"/callback" { t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath) } } // Test that each instance can handle requests independently for i, m := range middlewares { req := httptest.NewRequest("GET", routes[i]+"/protected", nil) rr := httptest.NewRecorder() m.ServeHTTP(rr, req) // Should redirect to auth URL since not authenticated if rr.Code != http.StatusFound { t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code) } location := rr.Header().Get("Location") if !strings.Contains(location, "https://test-issuer.com/auth") { t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location) } } } // TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances // with different provider URLs (e.g., different Keycloak realms) get separate // metadata refresh tasks. This addresses the issue reported in PR #88. func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode") } // Create two mock provider metadata servers simulating different Keycloak realms realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } metadata := ProviderMetadata{ Issuer: "https://keycloak.example.com/realms/realm1", AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth", TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token", JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs", EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout", } json.NewEncoder(w).Encode(metadata) })) defer realm1Server.Close() realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } metadata := ProviderMetadata{ Issuer: "https://keycloak.example.com/realms/realm2", AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth", TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token", JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs", EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout", } json.NewEncoder(w).Encode(metadata) })) defer realm2Server.Close() // Config for realm1 config1 := &Config{ ProviderURL: realm1Server.URL, ClientID: "realm1-client", ClientSecret: "realm1-secret", CallbackURL: "/realm1/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", CookiePrefix: "_oidc_realm1_", } // Config for realm2 config2 := &Config{ ProviderURL: realm2Server.URL, ClientID: "realm2-client", ClientSecret: "realm2-secret", CallbackURL: "/realm2/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", CookiePrefix: "_oidc_realm2_", } // Create middleware instances for both realms middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config1, "realm1-middleware") if err != nil { t.Fatalf("Failed to create middleware for realm1: %v", err) } middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config2, "realm2-middleware") if err != nil { t.Fatalf("Failed to create middleware for realm2: %v", err) } m1, ok1 := middleware1.(*TraefikOidc) m2, ok2 := middleware2.(*TraefikOidc) if !ok1 || !ok2 { t.Fatalf("Middleware is not of type *TraefikOidc") } // Clean up middleware instances defer func() { if err := m1.Close(); err != nil { t.Errorf("Failed to close realm1 middleware: %v", err) } if err := m2.Close(); err != nil { t.Errorf("Failed to close realm2 middleware: %v", err) } }() // Wait for both instances to initialize select { case <-m1.initComplete: case <-time.After(5 * time.Second): t.Fatalf("Realm1 middleware failed to initialize") } select { case <-m2.initComplete: case <-time.After(5 * time.Second): t.Fatalf("Realm2 middleware failed to initialize") } // Verify each instance has the correct issuer URL from their respective realms if !strings.Contains(m1.issuerURL, "realm1") { t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL) } if !strings.Contains(m2.issuerURL, "realm2") { t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL) } // Verify provider URLs are different if m1.providerURL == m2.providerURL { t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL) } // Test that each middleware can handle requests independently req1 := httptest.NewRequest("GET", "/realm1/protected", nil) rr1 := httptest.NewRecorder() m1.ServeHTTP(rr1, req1) req2 := httptest.NewRequest("GET", "/realm2/protected", nil) rr2 := httptest.NewRecorder() m2.ServeHTTP(rr2, req2) // Both should redirect to their respective auth URLs if rr1.Code != http.StatusFound { t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code) } if rr2.Code != http.StatusFound { t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code) } location1 := rr1.Header().Get("Location") location2 := rr2.Header().Get("Location") if !strings.Contains(location1, "realm1") { t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1) } if !strings.Contains(location2, "realm2") { t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2) } } // TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically // recovers when the OIDC provider becomes available after initial failure. func TestMetadataRecoveryOnProviderFailure(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode") } // Track whether the provider is "available" providerAvailable := false var mu sync.Mutex // Create mock provider that initially fails, then becomes available mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() available := providerAvailable mu.Unlock() if !available { w.WriteHeader(http.StatusServiceUnavailable) return } if r.URL.Path == "/.well-known/openid-configuration" { metadata := ProviderMetadata{ Issuer: "https://test-issuer.com", AuthURL: "https://test-issuer.com/auth", TokenURL: "https://test-issuer.com/token", JWKSURL: "https://test-issuer.com/jwks", EndSessionURL: "https://test-issuer.com/logout", } json.NewEncoder(w).Encode(metadata) return } w.WriteHeader(http.StatusNotFound) })) defer mockServer.Close() config := &Config{ ProviderURL: mockServer.URL, ClientID: "test-client", ClientSecret: "test-secret", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", } // Create middleware while provider is unavailable middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test-recovery") if err != nil { t.Fatalf("Failed to create middleware: %v", err) } m, ok := middleware.(*TraefikOidc) if !ok { t.Fatalf("Middleware is not of type *TraefikOidc") } defer m.Close() // Wait for initial initialization to complete (it should fail) select { case <-m.initComplete: case <-time.After(15 * time.Second): t.Fatal("Initialization did not complete in time") } // Verify initial state - should be in failed state (no issuerURL) m.metadataMu.RLock() initialIssuer := m.issuerURL m.metadataMu.RUnlock() if initialIssuer != "" { t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer) } // First request should get 503 req1 := httptest.NewRequest("GET", "/protected", nil) rr1 := httptest.NewRecorder() m.ServeHTTP(rr1, req1) if rr1.Code != http.StatusServiceUnavailable { t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code) } // Now make the provider available mu.Lock() providerAvailable = true mu.Unlock() // Reset the retry timer to allow immediate retry m.metadataRetryMutex.Lock() m.lastMetadataRetryTime = time.Time{} // Reset to zero time m.metadataRetryMutex.Unlock() // Second request should trigger recovery attempt req2 := httptest.NewRequest("GET", "/protected", nil) rr2 := httptest.NewRecorder() m.ServeHTTP(rr2, req2) // Give the async recovery a moment to complete time.Sleep(100 * time.Millisecond) // Check if recovery happened m.metadataMu.RLock() recoveredIssuer := m.issuerURL m.metadataMu.RUnlock() if recoveredIssuer == "" { t.Error("Expected issuerURL to be recovered after provider became available") } // Third request should succeed (redirect to auth, not 503) req3 := httptest.NewRequest("GET", "/protected", nil) rr3 := httptest.NewRecorder() m.ServeHTTP(rr3, req3) if rr3.Code == http.StatusServiceUnavailable { t.Errorf("Expected redirect after recovery, still got 503") } t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d", initialIssuer, recoveredIssuer, rr3.Code) } func TestServeHTTPRolesAndGroups(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Create consistent timestamps for all test cases now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew tests := []struct { allowedRolesAndGroups map[string]struct{} claims map[string]interface{} setupSession func(*SessionData) expectedHeaders map[string]string name string expectedStatus int }{ { name: "User with allowed role", allowedRolesAndGroups: map[string]struct{}{ "admin": {}, }, claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "roles": []interface{}{"admin", "user"}, "groups": []interface{}{"group1"}, "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ "X-User-Roles": "admin,user", "X-User-Groups": "group1", }, }, { name: "User with allowed group", allowedRolesAndGroups: map[string]struct{}{ "allowed-group": {}, }, claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"allowed-group"}, "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ "X-User-Roles": "user", "X-User-Groups": "allowed-group", }, }, { name: "User without allowed roles or groups", allowedRolesAndGroups: map[string]struct{}{ "admin": {}, "allowed-group": {}, }, claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"regular-group"}, "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") }, expectedStatus: http.StatusForbidden, }, { name: "No role/group restrictions", allowedRolesAndGroups: map[string]struct{}{}, claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "roles": []interface{}{"user"}, "groups": []interface{}{"regular-group"}, "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ "X-User-Roles": "user", "X-User-Groups": "regular-group", }, }, { name: "Claims without roles and groups", allowedRolesAndGroups: map[string]struct{}{}, claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "jti": generateRandomString(16), }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) session.SetEmail("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Create token with claims token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) if err != nil { t.Fatalf("Failed to create test token: %v", err) } // Create test handler nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) // Configure OIDC middleware tOidc := ts.tOidc tOidc.next = nextHandler tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups // Create request req := httptest.NewRequest("GET", "/protected", nil) rr := httptest.NewRecorder() // Set up session session, err := tOidc.sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } tc.setupSession(session) session.SetAccessToken(token) if err := session.Save(req, rr); err != nil { t.Fatalf("Failed to save session: %v", err) } // Copy cookies to the new request for _, cookie := range rr.Result().Cookies() { req.AddCookie(cookie) } // Reset response recorder rr = httptest.NewRecorder() // Serve request tOidc.ServeHTTP(rr, req) // Check status code if rr.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) } // Check headers if status is OK if tc.expectedStatus == http.StatusOK { for header, expectedValue := range tc.expectedHeaders { if value := req.Header.Get(header); value != expectedValue { t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value) } } } }) } } // Helper function to compare string slices func stringSliceEqual(a, b []string) bool { if len(a) != len(b) { return false } for i := range a { if a[i] != b[i] { return false } } return true } // TestExchangeTokensWithRedirects tests the token exchange process with redirects func TestExchangeTokensWithRedirects(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { setupServer func() *httptest.Server name string errorContains string expectError bool }{ { name: "Successful token exchange with redirects", setupServer: func() *httptest.Server { redirectCount := 0 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if redirectCount < 3 { // Set a cookie before redirecting http.SetCookie(w, &http.Cookie{ Name: fmt.Sprintf("redirect-cookie-%d", redirectCount), Value: "test-value", }) redirectCount++ w.Header().Set("Location", r.URL.String()) w.WriteHeader(http.StatusFound) return } // Verify all cookies from previous redirects are present cookies := r.Cookies() if len(cookies) != 3 { t.Errorf("Expected 3 cookies, got %d", len(cookies)) } for i := range 3 { found := false expectedName := fmt.Sprintf("redirect-cookie-%d", i) for _, cookie := range cookies { if cookie.Name == expectedName { found = true break } } if !found { t.Errorf("Cookie %s not found", expectedName) } } // Return successful token response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TokenResponse{ IDToken: "test.id.token", AccessToken: "test-access-token", TokenType: "Bearer", ExpiresIn: 3600, RefreshToken: "test-refresh-token", }) })) }, expectError: false, }, { name: "Too many redirects", setupServer: func() *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Location", r.URL.String()) w.WriteHeader(http.StatusFound) })) }, expectError: true, errorContains: "stopped after 50 redirects", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { server := tc.setupServer() defer server.Close() // Configure the test instance tOidc := ts.tOidc tOidc.tokenURL = server.URL // Test token exchange response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier") if tc.expectError { if err == nil { t.Error("Expected error but got nil") } else if !strings.Contains(err.Error(), tc.errorContains) { t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error()) } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } if response == nil { t.Error("Expected token response but got nil") } else if response.IDToken != "test.id.token" { t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken) } } }) } } // TestBuildAuthURL tests the buildAuthURL function with various URL scenarios func TestBuildAuthURL(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string authURL string issuerURL string redirectURL string state string nonce string codeChallenge string expectedPrefix string enablePKCE bool checkPKCE bool }{ { name: "Absolute Auth URL", authURL: "https://auth.example.com/oauth/authorize", issuerURL: "https://auth.example.com", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: false, codeChallenge: "", expectedPrefix: "https://auth.example.com/oauth/authorize?", checkPKCE: false, }, { name: "Relative Auth URL", authURL: "/oidc/auth", issuerURL: "https://logto.example.com", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: false, codeChallenge: "", expectedPrefix: "https://logto.example.com/oidc/auth?", checkPKCE: false, }, { name: "Relative Auth URL with Different Issuer", authURL: "/sign-in", issuerURL: "https://auth.example.com:8443", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: false, codeChallenge: "", expectedPrefix: "https://auth.example.com:8443/sign-in?", checkPKCE: false, }, { name: "With PKCE Enabled", authURL: "https://auth.example.com/oauth/authorize", issuerURL: "https://auth.example.com", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: true, codeChallenge: "test-code-challenge", expectedPrefix: "https://auth.example.com/oauth/authorize?", checkPKCE: true, }, { name: "With PKCE Enabled but No Challenge", authURL: "https://auth.example.com/oauth/authorize", issuerURL: "https://auth.example.com", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: true, codeChallenge: "", expectedPrefix: "https://auth.example.com/oauth/authorize?", checkPKCE: false, }, { name: "With PKCE Disabled but Challenge Provided", authURL: "https://auth.example.com/oauth/authorize", issuerURL: "https://auth.example.com", redirectURL: "https://app.example.com/callback", state: "test-state", nonce: "test-nonce", enablePKCE: false, codeChallenge: "test-code-challenge", expectedPrefix: "https://auth.example.com/oauth/authorize?", checkPKCE: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Configure the test instance tOidc := ts.tOidc tOidc.authURL = tc.authURL tOidc.issuerURL = tc.issuerURL tOidc.enablePKCE = tc.enablePKCE // Call buildAuthURL with code challenge result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge) // Verify the URL starts with the expected prefix if !strings.HasPrefix(result, tc.expectedPrefix) { t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result) } // Parse the resulting URL to verify query parameters parsedURL, err := url.Parse(result) if err != nil { t.Fatalf("Failed to parse resulting URL: %v", err) } query := parsedURL.Query() expectedParams := map[string]string{ "client_id": tOidc.clientID, "response_type": "code", "redirect_uri": tc.redirectURL, "state": tc.state, "nonce": tc.nonce, } for key, expected := range expectedParams { if got := query.Get(key); got != expected { t.Errorf("Expected %s=%q, got %q", key, expected, got) } } // Verify PKCE parameters if tc.checkPKCE { if got := query.Get("code_challenge"); got != tc.codeChallenge { t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got) } if got := query.Get("code_challenge_method"); got != "S256" { t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got) } } else { if got := query.Get("code_challenge"); got != "" { t.Errorf("Expected no code_challenge, but got %q", got) } if got := query.Get("code_challenge_method"); got != "" { t.Errorf("Expected no code_challenge_method, but got %q", got) } } // Verify scopes are present and correct if len(tOidc.scopes) > 0 { expectedScopes := strings.Join(tOidc.scopes, " ") if got := query.Get("scope"); got != expectedScopes { t.Errorf("Expected scope=%q, got %q", expectedScopes, got) } } }) } } // TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support func TestExchangeCodeForToken(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { setupMock func(t *testing.T) *httptest.Server name string codeVerifier string enablePKCE bool }{ { name: "With PKCE Enabled and Code Verifier", enablePKCE: true, codeVerifier: "test-code-verifier", setupMock: func(t *testing.T) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatalf("Failed to parse form: %v", err) } // Verify code_verifier is included if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" { t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier) } // Return successful token response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TokenResponse{ IDToken: "test.id.token", AccessToken: "test-access-token", TokenType: "Bearer", ExpiresIn: 3600, RefreshToken: "test-refresh-token", }) })) }, }, { name: "With PKCE Disabled but Code Verifier Provided", enablePKCE: false, codeVerifier: "test-code-verifier", setupMock: func(t *testing.T) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatalf("Failed to parse form: %v", err) } // Verify code_verifier is NOT included if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" { t.Errorf("Expected no code_verifier, got %s", codeVerifier) } // Return successful token response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TokenResponse{ IDToken: "test.id.token", AccessToken: "test-access-token", TokenType: "Bearer", ExpiresIn: 3600, RefreshToken: "test-refresh-token", }) })) }, }, { name: "With PKCE Enabled but No Code Verifier", enablePKCE: true, codeVerifier: "", setupMock: func(t *testing.T) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatalf("Failed to parse form: %v", err) } // Verify code_verifier is NOT included if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" { t.Errorf("Expected no code_verifier, got %s", codeVerifier) } // Return successful token response w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(TokenResponse{ IDToken: "test.id.token", AccessToken: "test-access-token", TokenType: "Bearer", ExpiresIn: 3600, RefreshToken: "test-refresh-token", }) })) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { server := tc.setupMock(t) defer server.Close() // Configure the test instance tOidc := ts.tOidc tOidc.tokenURL = server.URL tOidc.enablePKCE = tc.enablePKCE // Test exchangeCodeForToken response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier) if err != nil { t.Errorf("Unexpected error: %v", err) } if response == nil { t.Error("Expected token response but got nil") } else if response.IDToken != "test.id.token" { t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken) } }) } } // TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path. func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Create a request with query parameters req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil) responseRecorder := httptest.NewRecorder() // Get session session, err := ts.sessionManager.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } // Call defaultInitiateAuthentication redirectURL := "http://example.com/callback" ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL) // Verify that the incoming path includes query parameters incomingPath := session.GetIncomingPath() expectedPath := "/protected/resource?param1=value1¶m2=value2" if incomingPath != expectedPath { t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath) } } // TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances. func TestVerifyTimeConstraint(t *testing.T) { // Define tolerances used in jwt.go (ensure they match) toleranceFuture := 2 * time.Minute tolerancePast := 10 * time.Second now := time.Now() tests := []struct { name string claimTime time.Time claimName string futureCheck bool // true for exp, false for iat/nbf expectError bool }{ // Expiration (future=true, tolerance=2min) { name: "EXP: Valid (expires in 1 min)", claimTime: now.Add(1 * time.Minute), claimName: "exp", futureCheck: true, expectError: false, }, { name: "EXP: Expired (expired 3 min ago)", claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance claimName: "exp", futureCheck: true, expectError: true, }, { name: "EXP: Valid (expired 1 min ago, within 2min tolerance)", claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance claimName: "exp", futureCheck: true, expectError: false, // Should be allowed due to future tolerance }, // Issued At (future=false, tolerance=10s) { name: "IAT: Valid (issued 1 min ago)", claimTime: now.Add(-1 * time.Minute), claimName: "iat", futureCheck: false, expectError: false, }, { name: "IAT: Invalid (issued 15 sec in future)", claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance claimName: "iat", futureCheck: false, expectError: true, // "token used before issued" }, { name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)", claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance claimName: "iat", futureCheck: false, expectError: false, // Should be allowed due to past tolerance }, // Not Before (future=false, tolerance=10s) { name: "NBF: Valid (active 1 min ago)", claimTime: now.Add(-1 * time.Minute), claimName: "nbf", futureCheck: false, expectError: false, }, { name: "NBF: Invalid (active in 15 sec)", claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance claimName: "nbf", futureCheck: false, expectError: true, // "token not yet valid" }, { name: "NBF: Valid (active in 5 sec, within 10s tolerance)", claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance claimName: "nbf", futureCheck: false, expectError: false, // Should be allowed due to past tolerance }, } // Temporarily adjust global tolerances for test consistency, then restore originalFutureTolerance := ClockSkewToleranceFuture originalPastTolerance := ClockSkewTolerancePast ClockSkewToleranceFuture = toleranceFuture ClockSkewTolerancePast = tolerancePast defer func() { ClockSkewToleranceFuture = originalFutureTolerance ClockSkewTolerancePast = originalPastTolerance }() for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Convert claim time to float64 unix timestamp unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9 var err error // Call the specific verification function which uses verifyTimeConstraint if tc.claimName == "exp" { err = verifyExpiration(unixTime) } else if tc.claimName == "iat" { err = verifyIssuedAt(unixTime) } else if tc.claimName == "nbf" { err = verifyNotBefore(unixTime) } else { t.Fatalf("Unknown claim name in test setup: %s", tc.claimName) } if tc.expectError { if err == nil { t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now) } else { t.Logf("Got expected error: %v", err) // Log the error for confirmation } } else { if err != nil { t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err) } } }) } } // Add missing closing brace for TestVerifyTimeConstraint // ===== JWT REPLAY DETECTION TESTS ===== // These tests ensure the replay detection fix works correctly and prevents regressions // TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality func TestJWTVerifyWithSkipReplayCheck(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache before test cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } jwt, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT: %v", err) } tests := []struct { name string errorContains string skipReplayCheck bool firstCall bool expectError bool }{ { name: "First verification with skipReplayCheck=false should succeed", skipReplayCheck: false, firstCall: true, expectError: false, }, { name: "Second verification with skipReplayCheck=false should fail (replay detected)", skipReplayCheck: false, firstCall: false, expectError: true, errorContains: "token replay detected", }, { name: "Verification with skipReplayCheck=true should always succeed", skipReplayCheck: true, firstCall: false, // Even on subsequent calls expectError: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if tc.firstCall { // Clear replay cache for first call tests cleanupReplayCache() initReplayCache() } err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck) if tc.expectError { if err == nil { t.Errorf("Expected error containing '%s', but got nil", tc.errorContains) } else if !strings.Contains(err.Error(), tc.errorContains) { t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err) } } else { if err != nil { t.Errorf("Expected no error, but got: %v", err) } } }) } } // TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking func TestJWTVerifyBackwardCompatibility(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } jwt, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT: %v", err) } // First call with old signature (no skipReplayCheck parameter) should succeed err = jwt.Verify("https://test-issuer.com", "test-client-id") if err != nil { t.Errorf("First verification should succeed, got: %v", err) } // Second call with old signature should fail due to replay detection err = jwt.Verify("https://test-issuer.com", "test-client-id") if err == nil { t.Error("Second verification should fail due to replay detection") } else if !strings.Contains(err.Error(), "token replay detected") { t.Errorf("Expected 'token replay detected' error, got: %v", err) } } // TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // Simulate the authentication flow that was causing false positives: // 1. Initial authentication adds JTI to cache // 2. Subsequent request validation should not trigger false positive // Step 1: Initial authentication (this would add JTI to cache) jwt1, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT for initial auth: %v", err) } err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check if err != nil { t.Fatalf("Initial authentication should succeed: %v", err) } // Step 2: Subsequent request validation (this should skip replay check to avoid false positive) jwt2, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT for subsequent request: %v", err) } err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check if err != nil { t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err) } // Step 3: Verify that actual replay attacks are still detected jwt3, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT for replay attack test: %v", err) } err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check if err == nil { t.Error("Actual replay attack should be detected when skipReplayCheck=false") } else if !strings.Contains(err.Error(), "token replay detected") { t.Errorf("Expected 'token replay detected' error, got: %v", err) } } // TestAuthenticationFlowReplayDetection tests the complete authentication flow func TestAuthenticationFlowReplayDetection(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // Test the complete flow: // 1. Initial authentication (should add JTI to cache) // 2. Multiple subsequent requests (should not trigger false positives) // 3. Actual replay attack from different source (should be detected) // Step 1: Initial authentication err = ts.tOidc.VerifyToken(token) if err != nil { t.Fatalf("Initial authentication should succeed: %v", err) } // Verify JTI is in cache (use shardedReplayCache which is the actual cache used) exists := shardedReplayCache.Exists(jti) if !exists { t.Error("JTI should be added to replay cache during initial authentication") } // Step 2: Subsequent requests (simulate normal request processing) // These should use the token cache and skip replay detection for i := range 3 { err = ts.tOidc.VerifyToken(token) if err != nil { t.Errorf("Subsequent request %d should succeed: %v", i+1, err) } } // Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check jwt, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT for replay attack test: %v", err) } err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check if err == nil { t.Error("Actual replay attack should be detected") } else if !strings.Contains(err.Error(), "token replay detected") { t.Errorf("Expected 'token replay detected' error, got: %v", err) } } // TestActualReplayAttackDetection ensures real replay attacks are still properly detected func TestActualReplayAttackDetection(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } jwt, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT: %v", err) } // First verification should succeed err = jwt.Verify("https://test-issuer.com", "test-client-id", false) if err != nil { t.Fatalf("First verification should succeed: %v", err) } // Simulate different types of replay attacks replayTests := []struct { name string description string }{ { name: "Direct replay attack", description: "Same token used again with replay checking enabled", }, { name: "Replay from different source", description: "Token intercepted and replayed by attacker", }, } for _, rt := range replayTests { t.Run(rt.name, func(t *testing.T) { // Parse token again (simulating replay) replayJWT, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT for replay test: %v", err) } // Attempt replay with normal replay checking err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false) if err == nil { t.Errorf("Replay attack should be detected for: %s", rt.description) } else if !strings.Contains(err.Error(), "token replay detected") { t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err) } }) } } // TestConcurrentTokenValidation tests thread safety of replay detection func TestConcurrentTokenValidation(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Configure rate limiter to allow more requests for concurrent testing ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000 // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create multiple tokens with unique JTIs var tokens []string var jtis []string now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() for i := range 10 { jti := generateRandomString(16) jtis = append(jtis, jti) token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT %d: %v", i, err) } tokens = append(tokens, token) } // Test concurrent validation const numGoroutines = 20 const numIterations = 5 results := make(chan error, numGoroutines*numIterations) for g := range numGoroutines { go func(goroutineID int) { for i := range numIterations { tokenIndex := (goroutineID + i) % len(tokens) token := tokens[tokenIndex] // First validation should succeed err := ts.tOidc.VerifyToken(token) results <- err // Subsequent validation with same token should also succeed (uses cache) err = ts.tOidc.VerifyToken(token) results <- err } }(g) } // Collect results var errors []error for range numGoroutines * numIterations * 2 { if err := <-results; err != nil { errors = append(errors, err) } } // All validations should succeed (no race conditions) if len(errors) > 0 { t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors) } // Verify all JTIs are in cache (use shardedReplayCache which is the actual cache used) for i, jti := range jtis { if !shardedReplayCache.Exists(jti) { t.Errorf("JTI %d (%s) should be in replay cache", i, jti) } } } // TestJTIBlacklistBehavior tests the JTI blacklist cache management func TestJTIBlacklistBehavior(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Properly reinitialize the global replay cache cleanupReplayCache() // Clean up any existing cache and reset sync.Once initReplayCache() // Initialize new cache through proper channel // Create a test JWT with unique JTI jti := generateRandomString(16) t.Logf("TestJTIBlacklistBehavior - JTI: %s", jti) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // Test JTI blacklist behavior tests := []struct { action func() error name string description string expectError bool }{ { name: "Initial verification adds JTI to blacklist", action: func() error { return ts.tOidc.VerifyToken(token) }, expectError: false, description: "First verification should succeed and add JTI to blacklist", }, { name: "JTI exists in blacklist after verification", action: func() error { // Use shardedReplayCache which is the actual cache used if !shardedReplayCache.Exists(jti) { return fmt.Errorf("JTI not found in blacklist cache") } return nil }, expectError: false, description: "JTI should be present in blacklist cache", }, { name: "Subsequent verification uses cache (no replay check)", action: func() error { return ts.tOidc.VerifyToken(token) }, expectError: false, description: "Subsequent verification should succeed using token cache", }, { name: "Direct JWT verification detects replay", action: func() error { jwt, err := parseJWT(token) if err != nil { return err } return jwt.Verify("https://test-issuer.com", "test-client-id", false) }, expectError: true, description: "Direct JWT verification should detect replay", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { err := tc.action() if tc.expectError { if err == nil { t.Errorf("Expected error for %s, but got nil", tc.description) } } else { if err != nil { t.Errorf("Expected no error for %s, but got: %v", tc.description, err) } } }) } } // TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios func TestSessionBasedTokenRevalidation(t *testing.T) { if testing.Short() { t.Skip("Skipping session-based token revalidation test in short mode") } ts := NewTestSuite(t) ts.Setup() // Clear the global replay cache cleanupReplayCache() initReplayCache() // Create a test JWT with unique JTI jti := generateRandomString(16) now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // Simulate session-based token revalidation scenario // This tests the specific case that was causing false positives // Step 1: Initial authentication (callback processing) err = ts.tOidc.VerifyToken(token) if err != nil { t.Fatalf("Initial authentication should succeed: %v", err) } // Step 2: Multiple session-based requests (normal request processing) // These should not trigger replay detection false positives for i := range 5 { err = ts.tOidc.VerifyToken(token) if err != nil { t.Errorf("Session request %d should succeed: %v", i+1, err) } } // Step 3: Verify token is in both caches appropriately // Check token cache if _, exists := ts.tOidc.tokenCache.Get(token); !exists { t.Error("Token should be in token cache") } // Check replay cache // Use shardedReplayCache which is the actual cache used inReplayCache := shardedReplayCache.Exists(jti) if !inReplayCache { t.Error("JTI should be in replay cache") } // Step 4: Verify that clearing token cache still allows validation ts.tOidc.tokenCache = NewTokenCache() // Clear token cache err = ts.tOidc.VerifyToken(token) if err != nil { t.Errorf("Token validation should succeed even after cache clear: %v", err) } } // TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Properly reinitialize the global replay cache cleanupReplayCache() // Clean up any existing cache and reset sync.Once initReplayCache() // Initialize new cache through proper channel now := time.Now() exp := now.Add(1 * time.Hour).Unix() iat := now.Unix() nbf := now.Unix() tests := []struct { claims map[string]interface{} name string tokenType string expectError bool }{ { name: "ID Token with JTI", tokenType: "id_token", claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", "jti": generateRandomString(16), "token_type": "id_token", }, expectError: false, }, { name: "Access Token with JTI", tokenType: "access_token", claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "scope": "openid profile email", "jti": generateRandomString(16), "token_type": "access_token", }, expectError: false, }, { name: "Token without JTI", tokenType: "no_jti", claims: map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf, "sub": "test-subject", "email": "user@example.com", "nonce": "test-nonce", // No JTI claim }, expectError: false, // Should still work, just no replay protection }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Create token with specific claims token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // First verification should succeed err = ts.tOidc.VerifyToken(token) if tc.expectError { if err == nil { t.Errorf("Expected error for token type %s, but got nil", tc.tokenType) } } else { if err != nil { t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err) } } // Second verification should also succeed (uses cache) if !tc.expectError { err = ts.tOidc.VerifyToken(token) if err != nil { t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err) } } // Test direct JWT verification for replay detection if !tc.expectError && tc.claims["jti"] != nil { jwt, err := parseJWT(token) if err != nil { t.Fatalf("Failed to parse JWT: %v", err) } // This should detect replay for tokens with JTI err = jwt.Verify("https://test-issuer.com", "test-client-id", false) if err == nil { t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType) } else if !strings.Contains(err.Error(), "token replay detected") { t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err) } } }) } } // TestScopeMerging tests the scope append functionality func TestScopeMerging(t *testing.T) { // Helper function to compare string slices equalSlices := func(a, b []string) bool { if len(a) != len(b) { return false } for i, v := range a { if v != b[i] { return false } } return true } tests := []struct { name string defaultScopes []string userScopes []string expectedScopes []string }{ { name: "Empty user scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{}, expectedScopes: []string{"openid", "profile", "email"}, }, { name: "Nil user scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: nil, expectedScopes: []string{"openid", "profile", "email"}, }, { name: "New scopes are appended", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"custom_scope", "another_scope"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"}, }, { name: "Deduplication - user scope already in defaults", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"openid", "custom_scope"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope"}, }, { name: "Duplicate user scopes are removed", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"custom_scope", "custom_scope", "another_scope"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"}, }, { name: "Multiple overlapping scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"profile", "custom_scope", "email", "another_scope", "profile"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"}, }, { name: "Only custom scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"read:users", "write:users", "admin"}, expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"}, }, { name: "Empty defaults", defaultScopes: []string{}, userScopes: []string{"custom1", "custom2"}, expectedScopes: []string{"custom1", "custom2"}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Test the mergeScopes function directly result := mergeScopes(tc.defaultScopes, tc.userScopes) if !equalSlices(result, tc.expectedScopes) { t.Errorf("Expected %v, got %v", tc.expectedScopes, result) } }) } } // TestScopeMergingEdgeCases tests additional edge cases for scope deduplication func TestScopeMergingEdgeCases(t *testing.T) { // Helper function to compare string slices equalSlices := func(a, b []string) bool { if len(a) != len(b) { return false } for i, v := range a { if v != b[i] { return false } } return true } tests := []struct { name string description string defaultScopes []string userScopes []string expectedScopes []string }{ { name: "Case sensitivity preserved", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"OpenID", "PROFILE", "custom"}, expectedScopes: []string{"openid", "profile", "email", "OpenID", "PROFILE", "custom"}, description: "OAuth scopes are case-sensitive, so different cases should be preserved", }, { name: "Empty strings in user scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"", "custom", "", "another"}, expectedScopes: []string{"openid", "profile", "email", "", "custom", "another"}, description: "Empty strings should be preserved (though invalid in OAuth)", }, { name: "Whitespace scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{" ", "custom", " ", "another"}, expectedScopes: []string{"openid", "profile", "email", " ", "custom", " ", "another"}, description: "Whitespace-only scopes should be preserved as distinct", }, { name: "Large number of scopes", defaultScopes: []string{"openid", "profile", "email"}, userScopes: generateLargeUserScopes(), expectedScopes: func() []string { // Manually calculate expected result with proper deduplication defaults := []string{"openid", "profile", "email"} userScopes := generateLargeUserScopes() return mergeScopes(defaults, userScopes) }(), description: "Performance test with larger scope lists", }, { name: "Complex OAuth scopes with special characters", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"}, expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"}, description: "Real-world OAuth scopes with colons, slashes, and special characters", }, { name: "Duplicate defaults in user scopes multiple times", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"openid", "profile", "openid", "custom", "email", "profile", "custom"}, expectedScopes: []string{"openid", "profile", "email", "custom"}, description: "Multiple duplicates of default scopes should be completely deduplicated", }, { name: "All user scopes are duplicates of defaults", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"email", "openid", "profile", "openid"}, expectedScopes: []string{"openid", "profile", "email"}, description: "When all user scopes duplicate defaults, result should be just defaults", }, { name: "Single scope scenarios", defaultScopes: []string{"openid"}, userScopes: []string{"custom"}, expectedScopes: []string{"openid", "custom"}, description: "Minimal case with single scopes", }, { name: "Identical scopes in same order", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"openid", "profile", "email"}, expectedScopes: []string{"openid", "profile", "email"}, description: "When user scopes exactly match defaults, no duplication", }, { name: "Identical scopes in different order", defaultScopes: []string{"openid", "profile", "email"}, userScopes: []string{"email", "profile", "openid"}, expectedScopes: []string{"openid", "profile", "email"}, description: "Order of defaults is preserved when user scopes are reordered duplicates", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Test the mergeScopes function directly result := mergeScopes(tc.defaultScopes, tc.userScopes) if !equalSlices(result, tc.expectedScopes) { t.Errorf("Expected %v, got %v\nDescription: %s", tc.expectedScopes, result, tc.description) } }) } } // generateLargeUserScopes creates a large list of user scopes for performance testing func generateLargeUserScopes() []string { scopes := make([]string, 100) for i := range 100 { scopes[i] = fmt.Sprintf("scope_%d", i) } // Add some duplicates to test deduplication performance scopes = append(scopes, "scope_1", "scope_5", "scope_10", "openid") // Include a default duplicate return scopes } // TestScopeMergingPerformance tests performance with large scope lists func TestScopeMergingPerformance(t *testing.T) { // Create large scope lists defaultScopes := []string{"openid", "profile", "email"} // Create 1000 user scopes with some duplicates userScopes := make([]string, 1000) for i := range 1000 { if i%10 == 0 { // Add some duplicates of defaults userScopes[i] = defaultScopes[i%len(defaultScopes)] } else if i%7 == 0 { // Add some internal duplicates userScopes[i] = fmt.Sprintf("scope_%d", i%50) } else { userScopes[i] = fmt.Sprintf("scope_%d", i) } } // Measure performance start := time.Now() result := mergeScopes(defaultScopes, userScopes) duration := time.Since(start) // Verify result correctness if len(result) < len(defaultScopes) { t.Errorf("Result should contain at least the default scopes") } // Verify no duplicates exist seen := make(map[string]bool) for _, scope := range result { if seen[scope] { t.Errorf("Duplicate scope found in result: %s", scope) } seen[scope] = true } // Performance assertion (should be very fast) if duration > time.Millisecond*10 { t.Logf("Performance note: mergeScopes took %v for 1000+ scopes (still acceptable)", duration) } t.Logf("Performance: processed %d user scopes in %v, result has %d unique scopes", len(userScopes), duration, len(result)) } // TestScopeMergingMemoryEfficiency tests memory efficiency of the mergeScopes function func TestScopeMergingMemoryEfficiency(t *testing.T) { defaultScopes := []string{"openid", "profile", "email"} userScopes := []string{"custom1", "custom2"} // Test that the function doesn't modify input slices originalDefaults := make([]string, len(defaultScopes)) copy(originalDefaults, defaultScopes) originalUser := make([]string, len(userScopes)) copy(originalUser, userScopes) result := mergeScopes(defaultScopes, userScopes) // Verify input slices are unchanged for i, scope := range defaultScopes { if scope != originalDefaults[i] { t.Errorf("Default scopes were modified: expected %s, got %s", originalDefaults[i], scope) } } for i, scope := range userScopes { if scope != originalUser[i] { t.Errorf("User scopes were modified: expected %s, got %s", originalUser[i], scope) } } // Verify result is independent result[0] = "modified" if defaultScopes[0] == "modified" { t.Error("Modifying result affected input defaults") } expectedLength := len(defaultScopes) + len(userScopes) if len(result) != expectedLength { t.Errorf("Expected result length %d, got %d", expectedLength, len(result)) } } // TestNewWithScopeAppending tests that the New function properly merges scopes func TestNewWithScopeAppending(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode") } // Create mock provider metadata server mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/.well-known/openid-configuration" { w.WriteHeader(http.StatusNotFound) return } metadata := ProviderMetadata{ Issuer: "https://test-issuer.com", AuthURL: "https://test-issuer.com/auth", TokenURL: "https://test-issuer.com/token", JWKSURL: "https://test-issuer.com/jwks", RevokeURL: "https://test-issuer.com/revoke", EndSessionURL: "https://test-issuer.com/end-session", } json.NewEncoder(w).Encode(metadata) })) defer mockServer.Close() tests := []struct { name string configScopes []string expectedScopes []string }{ { name: "Default scopes only", configScopes: []string{}, expectedScopes: []string{"openid", "profile", "email"}, }, { name: "Custom scopes appended", configScopes: []string{"custom_scope", "another_scope"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"}, }, { name: "Overlapping scopes deduplicated", configScopes: []string{"openid", "custom_scope"}, expectedScopes: []string{"openid", "profile", "email", "custom_scope"}, }, { name: "OAuth scopes", configScopes: []string{"read:users", "write:users", "admin"}, expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Create config with test scopes config := &Config{ ProviderURL: mockServer.URL, ClientID: "test-client", ClientSecret: "test-secret", CallbackURL: "/callback", SessionEncryptionKey: "test-encryption-key-thats-long-enough", Scopes: tc.configScopes, } // Create middleware instance middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }), config, "test") if err != nil { t.Fatalf("Failed to create middleware: %v", err) } // Wait for initialization if m, ok := middleware.(*TraefikOidc); ok { // Ensure middleware is properly closed to prevent goroutine leaks defer func() { if err := m.Close(); err != nil { t.Errorf("Failed to close middleware: %v", err) } }() select { case <-m.initComplete: case <-time.After(5 * time.Second): t.Fatalf("Middleware failed to initialize") } // Check that scopes were properly merged if !equalSlices(m.scopes, tc.expectedScopes) { t.Errorf("Expected scopes %v, got %v", tc.expectedScopes, m.scopes) } } else { t.Fatalf("Middleware is not of type *TraefikOidc") } }) } } // TestBuildAuthURLWithMergedScopes tests that the auth URL includes the properly merged scopes func TestBuildAuthURLWithMergedScopes(t *testing.T) { ts := NewTestSuite(t) ts.Setup() tests := []struct { name string expectedScopes string scopes []string }{ { name: "Default scopes only", scopes: []string{"openid", "profile", "email"}, expectedScopes: "openid profile email offline_access", }, { name: "Custom scopes appended", scopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"}, expectedScopes: "openid profile email custom_scope another_scope offline_access", }, { name: "OAuth scopes", scopes: []string{"openid", "profile", "email", "read:users", "write:users"}, expectedScopes: "openid profile email read:users write:users offline_access", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Configure the test instance with specific scopes tOidc := ts.tOidc tOidc.scopes = tc.scopes // These scopes are already deduplicated by New() tOidc.authURL = "https://auth.example.com/oauth/authorize" tOidc.issuerURL = "https://auth.example.com" // Reset overrideScopes for each test case, as it's part of tOidc state // Default to false, specific tests will set it. tOidc.overrideScopes = false // Build auth URL result := tOidc.buildAuthURL("https://app.example.com/callback", "test-state", "test-nonce", "") // Parse the resulting URL to verify scopes parsedURL, err := url.Parse(result) if err != nil { t.Fatalf("Failed to parse resulting URL: %v", err) } query := parsedURL.Query() actualScopes := query.Get("scope") if actualScopes != tc.expectedScopes { t.Errorf("Expected scopes %q, got %q", tc.expectedScopes, actualScopes) } }) } } // TestBuildAuthURL_OverrideScopes_And_OfflineAccess tests the offline_access logic in buildAuthURL // considering the overrideScopes flag. func TestBuildAuthURL_OverrideScopes_And_OfflineAccess(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Sets up ts.tOidc tests := []struct { expectedParams map[string]string name string expectedScope string initialScopes []string overrideScopes bool isGoogle bool isAzure bool }{ { name: "Override false, no user scopes, non-Google/Azure", initialScopes: []string{"openid", "profile", "email"}, // Defaults from New() when config.Scopes is empty overrideScopes: false, expectedScope: "openid profile email offline_access", }, { name: "Override false, user scopes without offline_access, non-Google/Azure", initialScopes: []string{"openid", "profile", "email", "custom1"}, // Merged and deduplicated by New() overrideScopes: false, expectedScope: "openid profile email custom1 offline_access", }, { name: "Override false, user scopes with offline_access, non-Google/Azure", initialScopes: []string{"openid", "profile", "email", "offline_access", "custom1"}, overrideScopes: false, expectedScope: "openid profile email offline_access custom1", // Order might vary based on merge, but offline_access present }, { name: "Override true, user scopes without offline_access, non-Google/Azure", initialScopes: []string{"custom1", "custom2"}, // Directly from config.Scopes, deduplicated overrideScopes: true, expectedScope: "custom1 custom2", // offline_access NOT added }, { name: "Override true, user scopes with offline_access, non-Google/Azure", initialScopes: []string{"custom1", "offline_access", "custom2"}, overrideScopes: true, expectedScope: "custom1 offline_access custom2", // User explicitly included it }, { name: "Override true, no user scopes (edge case), non-Google/Azure", initialScopes: []string{}, // config.Scopes was empty overrideScopes: true, // In this edge case, buildAuthURL's logic `(t.overrideScopes && len(t.scopes) == 0)` // will lead to offline_access being added, as it behaves like defaults. expectedScope: "offline_access", }, // Google Provider Tests (access_type=offline, prompt=consent) { name: "Google, Override false, no user scopes", initialScopes: []string{"openid", "profile", "email"}, overrideScopes: false, isGoogle: true, expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"}, expectedScope: "openid profile email", // No offline_access scope for Google }, { name: "Google, Override true, user scopes", initialScopes: []string{"custom1", "custom2"}, overrideScopes: true, isGoogle: true, expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"}, expectedScope: "custom1 custom2", // No offline_access scope for Google }, // Azure Provider Tests (response_mode=query, offline_access scope added if not present by user) { name: "Azure, Override false, no user scopes", initialScopes: []string{"openid", "profile", "email"}, overrideScopes: false, isAzure: true, expectedParams: map[string]string{"response_mode": "query"}, expectedScope: "openid profile email offline_access", }, { name: "Azure, Override true, user scopes without offline_access", initialScopes: []string{"custom1", "custom2"}, overrideScopes: true, isAzure: true, expectedParams: map[string]string{"response_mode": "query"}, expectedScope: "custom1 custom2", // offline_access NOT added by default when override is true }, { name: "Azure, Override true, user scopes with offline_access", initialScopes: []string{"custom1", "offline_access"}, overrideScopes: true, isAzure: true, expectedParams: map[string]string{"response_mode": "query"}, expectedScope: "custom1 offline_access", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { tOidc := ts.tOidc tOidc.scopes = tc.initialScopes // Set the scopes as if they came from New() tOidc.overrideScopes = tc.overrideScopes // Adjust issuerURL for provider-specific tests originalIssuerURL := tOidc.issuerURL if tc.isGoogle { tOidc.issuerURL = "https://accounts.google.com" } else if tc.isAzure { tOidc.issuerURL = "https://login.microsoftonline.com/common" } else { tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure } authURLString := tOidc.buildAuthURL("http://localhost/callback", "state123", "nonce123", "challenge123") parsedAuthURL, err := url.Parse(authURLString) if err != nil { t.Fatalf("Failed to parse auth URL: %v", err) } query := parsedAuthURL.Query() actualScope := query.Get("scope") if actualScope != tc.expectedScope { t.Errorf("Expected scope string %q, got %q", tc.expectedScope, actualScope) } if tc.expectedParams != nil { for k, v := range tc.expectedParams { if query.Get(k) != v { t.Errorf("Expected param %s=%s, got %s", k, v, query.Get(k)) } } } // Restore original issuerURL for next test tOidc.issuerURL = originalIssuerURL }) } } // TestBuildAuthURL_SpecificUserCase tests the buildAuthURL function with the specific user-reported scenario. func TestBuildAuthURL_SpecificUserCase(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Basic setup for tOidc // Configure the TraefikOidc instance for the specific scenario tOidc := ts.tOidc tOidc.scopes = []string{"email", "test3"} // This is what t.scopes should be after New() tOidc.overrideScopes = true tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure tOidc.authURL = "https://generic-provider.com/auth" // Dummy auth URL tOidc.clientID = "test-client-id" // Expected scope string in the URL expectedScopeString := "email test3" // Call buildAuthURL authURLString := tOidc.buildAuthURL("http://localhost/callback", "test-state", "test-nonce", "") // Parse the resulting URL parsedAuthURL, err := url.Parse(authURLString) if err != nil { t.Fatalf("Failed to parse generated auth URL %q: %v", authURLString, err) } // Get the 'scope' query parameter actualScopeString := parsedAuthURL.Query().Get("scope") // Assert that the scope string is as expected if actualScopeString != expectedScopeString { t.Errorf("Expected scope parameter to be %q, but got %q. Full URL: %s", expectedScopeString, actualScopeString, authURLString) } // Additionally, ensure 'offline_access' was not added if strings.Contains(actualScopeString, "offline_access") { t.Errorf("Scope parameter %q should not contain 'offline_access' when overrideScopes is true and it's not in tOidc.scopes", actualScopeString) } }