diff --git a/jwt.go b/jwt.go index dde5ac3..c40301e 100644 --- a/jwt.go +++ b/jwt.go @@ -123,11 +123,12 @@ func parseJWT(tokenString string) (*JWT, error) { // Parameters: // - issuerURL: The expected issuer URL (e.g., "https://accounts.google.com"). // - clientID: The expected audience value (the client ID of this application). +// - skipReplayCheck: If true, skips JTI replay detection (used for revalidation of cached tokens). // // Returns: // - nil if all standard claims are valid. // - An error describing the first validation failure encountered. -func (j *JWT) Verify(issuerURL, clientID string) error { +func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error { // Validate algorithm to prevent algorithm switching attacks alg, ok := j.Header["alg"].(string) if !ok { @@ -183,7 +184,10 @@ func (j *JWT) Verify(issuerURL, clientID string) error { } // Implement replay protection by checking the jti (JWT ID) - if jti, ok := claims["jti"].(string); ok { + // Skip replay check if explicitly requested (for revalidation scenarios) + shouldSkipReplay := len(skipReplayCheck) > 0 && skipReplayCheck[0] + + if jti, ok := claims["jti"].(string); ok && !shouldSkipReplay { // Skip replay detection for tokens that are being verified from the cache if j.Token == "" { // This is a parsed JWT without the original token string, diff --git a/main.go b/main.go index 576253e..45302f1 100644 --- a/main.go +++ b/main.go @@ -363,8 +363,8 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error return fmt.Errorf("signature verification failed: %w", err) } - // Verify standard claims - if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { + // Verify standard claims - skip replay check since it's already handled in VerifyToken + if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil { return fmt.Errorf("standard claim verification failed: %w", err) } diff --git a/main_test.go b/main_test.go index a65e74a..32423c4 100644 --- a/main_test.go +++ b/main_test.go @@ -2806,3 +2806,761 @@ func TestVerifyTimeConstraint(t *testing.T) { }) } } // Add missing closing brace for TestVerifyTimeConstraint + +// ===== JWT REPLAY DETECTION TESTS ===== +// These tests ensure the replay detection fix works correctly and prevents regressions + +// TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality +func TestJWTVerifyWithSkipReplayCheck(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache before test + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + jwt, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT: %v", err) + } + + tests := []struct { + name string + skipReplayCheck bool + firstCall bool + expectError bool + errorContains string + }{ + { + name: "First verification with skipReplayCheck=false should succeed", + skipReplayCheck: false, + firstCall: true, + expectError: false, + }, + { + name: "Second verification with skipReplayCheck=false should fail (replay detected)", + skipReplayCheck: false, + firstCall: false, + expectError: true, + errorContains: "token replay detected", + }, + { + name: "Verification with skipReplayCheck=true should always succeed", + skipReplayCheck: true, + firstCall: false, // Even on subsequent calls + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.firstCall { + // Clear replay cache for first call tests + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + } + + err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck) + + if tc.expectError { + if err == nil { + t.Errorf("Expected error containing '%s', but got nil", tc.errorContains) + } else if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Errorf("Expected no error, but got: %v", err) + } + } + }) + } +} + +// TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking +func TestJWTVerifyBackwardCompatibility(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + jwt, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT: %v", err) + } + + // First call with old signature (no skipReplayCheck parameter) should succeed + err = jwt.Verify("https://test-issuer.com", "test-client-id") + if err != nil { + t.Errorf("First verification should succeed, got: %v", err) + } + + // Second call with old signature should fail due to replay detection + err = jwt.Verify("https://test-issuer.com", "test-client-id") + if err == nil { + t.Error("Second verification should fail due to replay detection") + } else if !strings.Contains(err.Error(), "token replay detected") { + t.Errorf("Expected 'token replay detected' error, got: %v", err) + } +} + +// TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives +func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // Simulate the authentication flow that was causing false positives: + // 1. Initial authentication adds JTI to cache + // 2. Subsequent request validation should not trigger false positive + + // Step 1: Initial authentication (this would add JTI to cache) + jwt1, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT for initial auth: %v", err) + } + + err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check + if err != nil { + t.Fatalf("Initial authentication should succeed: %v", err) + } + + // Step 2: Subsequent request validation (this should skip replay check to avoid false positive) + jwt2, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT for subsequent request: %v", err) + } + + err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check + if err != nil { + t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err) + } + + // Step 3: Verify that actual replay attacks are still detected + jwt3, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT for replay attack test: %v", err) + } + + err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check + if err == nil { + t.Error("Actual replay attack should be detected when skipReplayCheck=false") + } else if !strings.Contains(err.Error(), "token replay detected") { + t.Errorf("Expected 'token replay detected' error, got: %v", err) + } +} + +// TestAuthenticationFlowReplayDetection tests the complete authentication flow +func TestAuthenticationFlowReplayDetection(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // Test the complete flow: + // 1. Initial authentication (should add JTI to cache) + // 2. Multiple subsequent requests (should not trigger false positives) + // 3. Actual replay attack from different source (should be detected) + + // Step 1: Initial authentication + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Fatalf("Initial authentication should succeed: %v", err) + } + + // Verify JTI is in cache + replayCacheMu.Lock() + _, exists := replayCache.Get(jti) + replayCacheMu.Unlock() + if !exists { + t.Error("JTI should be added to replay cache during initial authentication") + } + + // Step 2: Subsequent requests (simulate normal request processing) + // These should use the token cache and skip replay detection + for i := 0; i < 3; i++ { + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Errorf("Subsequent request %d should succeed: %v", i+1, err) + } + } + + // Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check + jwt, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT for replay attack test: %v", err) + } + + err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check + if err == nil { + t.Error("Actual replay attack should be detected") + } else if !strings.Contains(err.Error(), "token replay detected") { + t.Errorf("Expected 'token replay detected' error, got: %v", err) + } +} + +// TestActualReplayAttackDetection ensures real replay attacks are still properly detected +func TestActualReplayAttackDetection(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + jwt, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT: %v", err) + } + + // First verification should succeed + err = jwt.Verify("https://test-issuer.com", "test-client-id", false) + if err != nil { + t.Fatalf("First verification should succeed: %v", err) + } + + // Simulate different types of replay attacks + replayTests := []struct { + name string + description string + }{ + { + name: "Direct replay attack", + description: "Same token used again with replay checking enabled", + }, + { + name: "Replay from different source", + description: "Token intercepted and replayed by attacker", + }, + } + + for _, rt := range replayTests { + t.Run(rt.name, func(t *testing.T) { + // Parse token again (simulating replay) + replayJWT, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT for replay test: %v", err) + } + + // Attempt replay with normal replay checking + err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false) + if err == nil { + t.Errorf("Replay attack should be detected for: %s", rt.description) + } else if !strings.Contains(err.Error(), "token replay detected") { + t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err) + } + }) + } +} + +// TestConcurrentTokenValidation tests thread safety of replay detection +func TestConcurrentTokenValidation(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Configure rate limiter to allow more requests for concurrent testing + ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000 + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create multiple tokens with unique JTIs + var tokens []string + var jtis []string + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + for i := 0; i < 10; i++ { + jti := generateRandomString(16) + jtis = append(jtis, jti) + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT %d: %v", i, err) + } + tokens = append(tokens, token) + } + + // Test concurrent validation + const numGoroutines = 20 + const numIterations = 5 + + results := make(chan error, numGoroutines*numIterations) + + for g := 0; g < numGoroutines; g++ { + go func(goroutineID int) { + for i := 0; i < numIterations; i++ { + tokenIndex := (goroutineID + i) % len(tokens) + token := tokens[tokenIndex] + + // First validation should succeed + err := ts.tOidc.VerifyToken(token) + results <- err + + // Subsequent validation with same token should also succeed (uses cache) + err = ts.tOidc.VerifyToken(token) + results <- err + } + }(g) + } + + // Collect results + var errors []error + for i := 0; i < numGoroutines*numIterations*2; i++ { + if err := <-results; err != nil { + errors = append(errors, err) + } + } + + // All validations should succeed (no race conditions) + if len(errors) > 0 { + t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors) + } + + // Verify all JTIs are in cache + replayCacheMu.Lock() + for i, jti := range jtis { + if _, exists := replayCache.Get(jti); !exists { + t.Errorf("JTI %d (%s) should be in replay cache", i, jti) + } + } + replayCacheMu.Unlock() +} + +// TestJTIBlacklistBehavior tests the JTI blacklist cache management +func TestJTIBlacklistBehavior(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // Test JTI blacklist behavior + tests := []struct { + name string + action func() error + expectError bool + description string + }{ + { + name: "Initial verification adds JTI to blacklist", + action: func() error { + return ts.tOidc.VerifyToken(token) + }, + expectError: false, + description: "First verification should succeed and add JTI to blacklist", + }, + { + name: "JTI exists in blacklist after verification", + action: func() error { + replayCacheMu.Lock() + defer replayCacheMu.Unlock() + if _, exists := replayCache.Get(jti); !exists { + return fmt.Errorf("JTI not found in blacklist cache") + } + return nil + }, + expectError: false, + description: "JTI should be present in blacklist cache", + }, + { + name: "Subsequent verification uses cache (no replay check)", + action: func() error { + return ts.tOidc.VerifyToken(token) + }, + expectError: false, + description: "Subsequent verification should succeed using token cache", + }, + { + name: "Direct JWT verification detects replay", + action: func() error { + jwt, err := parseJWT(token) + if err != nil { + return err + } + return jwt.Verify("https://test-issuer.com", "test-client-id", false) + }, + expectError: true, + description: "Direct JWT verification should detect replay", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.action() + + if tc.expectError { + if err == nil { + t.Errorf("Expected error for %s, but got nil", tc.description) + } + } else { + if err != nil { + t.Errorf("Expected no error for %s, but got: %v", tc.description, err) + } + } + }) + } +} + +// TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios +func TestSessionBasedTokenRevalidation(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + // Create a test JWT with unique JTI + jti := generateRandomString(16) + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": jti, + }) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // Simulate session-based token revalidation scenario + // This tests the specific case that was causing false positives + + // Step 1: Initial authentication (callback processing) + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Fatalf("Initial authentication should succeed: %v", err) + } + + // Step 2: Multiple session-based requests (normal request processing) + // These should not trigger replay detection false positives + for i := 0; i < 5; i++ { + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Errorf("Session request %d should succeed: %v", i+1, err) + } + } + + // Step 3: Verify token is in both caches appropriately + // Check token cache + if _, exists := ts.tOidc.tokenCache.Get(token); !exists { + t.Error("Token should be in token cache") + } + + // Check replay cache + replayCacheMu.Lock() + _, inReplayCache := replayCache.Get(jti) + replayCacheMu.Unlock() + if !inReplayCache { + t.Error("JTI should be in replay cache") + } + + // Step 4: Verify that clearing token cache still allows validation + ts.tOidc.tokenCache = NewTokenCache() // Clear token cache + + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Errorf("Token validation should succeed even after cache clear: %v", err) + } +} + +// TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types +func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) { + ts := &TestSuite{t: t} + ts.Setup() + + // Clear the global replay cache + replayCacheMu.Lock() + replayCache = NewCache() + replayCache.SetMaxSize(10000) + replayCacheMu.Unlock() + + now := time.Now() + exp := now.Add(1 * time.Hour).Unix() + iat := now.Unix() + nbf := now.Unix() + + tests := []struct { + name string + tokenType string + claims map[string]interface{} + expectError bool + }{ + { + name: "ID Token with JTI", + tokenType: "id_token", + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + "jti": generateRandomString(16), + "token_type": "id_token", + }, + expectError: false, + }, + { + name: "Access Token with JTI", + tokenType: "access_token", + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "scope": "openid profile email", + "jti": generateRandomString(16), + "token_type": "access_token", + }, + expectError: false, + }, + { + name: "Token without JTI", + tokenType: "no_jti", + claims: map[string]interface{}{ + "iss": "https://test-issuer.com", + "aud": "test-client-id", + "exp": exp, + "iat": iat, + "nbf": nbf, + "sub": "test-subject", + "email": "user@example.com", + "nonce": "test-nonce", + // No JTI claim + }, + expectError: false, // Should still work, just no replay protection + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create token with specific claims + token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims) + if err != nil { + t.Fatalf("Failed to create test JWT: %v", err) + } + + // First verification should succeed + err = ts.tOidc.VerifyToken(token) + if tc.expectError { + if err == nil { + t.Errorf("Expected error for token type %s, but got nil", tc.tokenType) + } + } else { + if err != nil { + t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err) + } + } + + // Second verification should also succeed (uses cache) + if !tc.expectError { + err = ts.tOidc.VerifyToken(token) + if err != nil { + t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err) + } + } + + // Test direct JWT verification for replay detection + if !tc.expectError && tc.claims["jti"] != nil { + jwt, err := parseJWT(token) + if err != nil { + t.Fatalf("Failed to parse JWT: %v", err) + } + + // This should detect replay for tokens with JTI + err = jwt.Verify("https://test-issuer.com", "test-client-id", false) + if err == nil { + t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType) + } else if !strings.Contains(err.Error(), "token replay detected") { + t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err) + } + } + }) + } +}