package traefikoidc import ( "crypto/rand" "crypto/rsa" "encoding/base64" "fmt" "net/http" "net/http/httptest" "strings" "testing" "time" "golang.org/x/time/rate" ) // TestConfigAudienceValidation tests the Config.Validate() method for the audience field func TestConfigAudienceValidation(t *testing.T) { tests := []struct { name string audience string wantErr bool errContains string }{ { name: "Empty audience is valid for backward compatibility", audience: "", wantErr: false, }, { name: "Valid HTTPS URL audience Auth0 format", audience: "https://api.example.com", wantErr: false, }, { name: "Valid identifier audience", audience: "my-api", wantErr: false, }, { name: "Valid Azure AD Application ID URI format", audience: "api://12345-guid-67890", wantErr: false, }, { name: "Valid Auth0 API identifier", audience: "https://my-company.auth0.com/api/v2/", wantErr: false, }, { name: "HTTP URL audience should fail", audience: "http://api.example.com", wantErr: true, errContains: "must use HTTPS", }, { name: "Audience with wildcard should fail", audience: "https://api.*.example.com", wantErr: true, errContains: "must not contain wildcards", }, { name: "Audience with single asterisk should fail", audience: "*", wantErr: true, errContains: "must not contain wildcards", }, { name: "Audience over 256 characters should fail", audience: strings.Repeat("a", 257), wantErr: true, errContains: "must not exceed 256 characters", }, { name: "Audience with newline should fail", audience: "my-api\ninjection", wantErr: true, errContains: "contains invalid characters", }, { name: "Audience with carriage return should fail", audience: "my-api\rinjection", wantErr: true, errContains: "contains invalid characters", }, { name: "Audience with tab should fail", audience: "my-api\tinjection", wantErr: true, errContains: "contains invalid characters", }, { name: "Valid audience exactly 256 characters", audience: strings.Repeat("a", 256), wantErr: false, }, { name: "Valid simple identifier", audience: "my-service-api", wantErr: false, }, { name: "Valid URN format", audience: "urn:myservice:api:v1", wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := CreateConfig() config.ProviderURL = "https://provider.example.com" config.ClientID = "test-client-id" config.ClientSecret = "test-client-secret" config.CallbackURL = "/callback" config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength) config.Audience = tt.audience err := config.Validate() if (err != nil) != tt.wantErr { t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { t.Errorf("Error message should contain %q, got: %v", tt.errContains, err) } }) } } // TestJWTAudienceVerification tests JWT verification with custom audience values func TestJWTAudienceVerification(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Generate RSA key for signing JWTs rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey // Create JWK jwk := JWK{ Kty: "RSA", Kid: "test-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) tests := []struct { name string configAudience string tokenAudience interface{} wantErr bool errContains string skipReplayCheck bool }{ { name: "JWT with string aud matching configured audience", configAudience: "https://api.example.com", tokenAudience: "https://api.example.com", wantErr: false, skipReplayCheck: true, }, { name: "JWT with array aud containing configured audience", configAudience: "https://api.example.com", tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"}, wantErr: false, skipReplayCheck: true, }, { name: "JWT with string aud NOT matching configured audience", configAudience: "https://api.example.com", tokenAudience: "https://wrong-api.example.com", wantErr: true, errContains: "invalid audience", skipReplayCheck: true, }, { name: "JWT with array aud NOT containing configured audience", configAudience: "https://api.example.com", tokenAudience: []interface{}{"https://other.com", "https://another.com"}, wantErr: true, errContains: "invalid audience", skipReplayCheck: true, }, { name: "JWT with clientID as aud when no custom audience configured", configAudience: "", tokenAudience: "test-client-id", wantErr: false, skipReplayCheck: true, }, { name: "JWT with empty string aud", configAudience: "https://api.example.com", tokenAudience: "", wantErr: true, errContains: "invalid audience", skipReplayCheck: true, }, { name: "Azure AD Application ID URI format", configAudience: "api://12345-app-id", tokenAudience: "api://12345-app-id", wantErr: false, skipReplayCheck: true, }, { name: "Auth0 custom API audience", configAudience: "https://mycompany.com/api", tokenAudience: "https://mycompany.com/api", wantErr: false, skipReplayCheck: true, }, { name: "Token confusion attack - audience for different service", configAudience: "https://service-a.example.com", tokenAudience: "https://service-b.example.com", wantErr: true, errContains: "invalid audience", skipReplayCheck: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create TraefikOidc instance tOidc := &TraefikOidc{ issuerURL: "https://test-issuer.com", clientID: "test-client-id", clientSecret: "test-client-secret", jwkCache: mockJWKCache, jwksURL: "https://test-jwks-url.com", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, httpClient: &http.Client{}, } // Set up the token verifier and JWT verifier tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc // Determine the expected audience for validation expectedAudience := tt.configAudience if expectedAudience == "" { expectedAudience = tOidc.clientID } // Set the audience field on the tOidc instance tOidc.audience = expectedAudience // Create JWT with specified audience jti := generateRandomString(16) if tt.skipReplayCheck { // Use a unique JTI for each test to avoid replay detection jti = fmt.Sprintf("test-%s-%s", tt.name, jti) } jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": tt.tokenAudience, "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "test-subject", "email": "user@example.com", "jti": jti, }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } // Verify the token err = tOidc.VerifyToken(jwt) if (err != nil) != tt.wantErr { t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr) return } if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { t.Errorf("Error message should contain %q, got: %v", tt.errContains, err) } }) } } // TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved // when the Audience field is not set func TestJWTAudienceBackwardCompatibility(t *testing.T) { ts := NewTestSuite(t) ts.Setup() // Test with no custom audience configured - should use clientID jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", // Should match clientID "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create test JWT: %v", err) } err = ts.tOidc.VerifyToken(jwt) if err != nil { t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err) } } // TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case func TestAudienceIntegrationAuth0Scenario(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Simulate Auth0 scenario: custom audience for API access config := CreateConfig() config.ProviderURL = "https://mycompany.auth0.com" config.ClientID = "auth0-client-id" config.ClientSecret = "auth0-client-secret" config.CallbackURL = "/callback" config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength) config.Audience = "https://api.mycompany.com" // Custom API audience // Validate config if err := config.Validate(); err != nil { t.Fatalf("Auth0 config validation failed: %v", err) } // Generate test keys rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey jwk := JWK{ Kty: "RSA", Kid: "auth0-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) tOidc := &TraefikOidc{ issuerURL: config.ProviderURL, clientID: config.ClientID, clientSecret: config.ClientSecret, audience: config.Audience, // Set audience from config jwkCache: mockJWKCache, jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, httpClient: &http.Client{}, } // Default audience to clientID if not specified if tOidc.audience == "" { tOidc.audience = tOidc.clientID } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) { jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{ "iss": config.ProviderURL, "aud": config.Audience, // Matches configured audience "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "auth0|123456", "email": "user@example.com", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create Auth0 JWT: %v", err) } err = tOidc.VerifyToken(jwt) if err != nil { t.Errorf("Auth0 token verification failed: %v", err) } }) t.Run("Auth0 token with clientID instead of API audience should fail", func(t *testing.T) { jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{ "iss": config.ProviderURL, "aud": config.ClientID, // Using clientID instead of API audience "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "auth0|123456", "email": "user@example.com", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create Auth0 JWT: %v", err) } err = tOidc.VerifyToken(jwt) if err == nil { t.Error("Auth0 token with wrong audience should have been rejected") } else if !strings.Contains(err.Error(), "invalid audience") { t.Errorf("Expected 'invalid audience' error, got: %v", err) } }) } // TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case func TestAudienceIntegrationAzureADScenario(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Simulate Azure AD scenario: Application ID URI format config := CreateConfig() config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0" config.ClientID = "azure-client-id" config.ClientSecret = "azure-client-secret" config.CallbackURL = "/callback" config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength) config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI // Validate config if err := config.Validate(); err != nil { t.Fatalf("Azure AD config validation failed: %v", err) } // Generate test keys rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey jwk := JWK{ Kty: "RSA", Kid: "azure-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) tOidc := &TraefikOidc{ issuerURL: config.ProviderURL, clientID: config.ClientID, clientSecret: config.ClientSecret, audience: config.Audience, // Set audience from config jwkCache: mockJWKCache, jwksURL: config.ProviderURL + "/.well-known/jwks.json", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, httpClient: &http.Client{}, } // Default audience to clientID if not specified if tOidc.audience == "" { tOidc.audience = tOidc.clientID } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) { jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{ "iss": config.ProviderURL, "aud": config.Audience, // Matches Application ID URI "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "azure-user-id", "email": "user@example.com", "oid": "object-id-12345", "tid": "tenant-id", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create Azure AD JWT: %v", err) } err = tOidc.VerifyToken(jwt) if err != nil { t.Errorf("Azure AD token verification failed: %v", err) } }) t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) { jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{ "iss": config.ProviderURL, "aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"}, "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "azure-user-id", "email": "user@example.com", "oid": "object-id-12345", "tid": "tenant-id", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create Azure AD JWT: %v", err) } err = tOidc.VerifyToken(jwt) if err != nil { t.Errorf("Azure AD token with multiple audiences verification failed: %v", err) } }) } // TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks func TestAudienceSecurityTokenConfusionAttack(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Generate test keys rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey jwk := JWK{ Kty: "RSA", Kid: "test-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) // Service A configuration serviceA := &TraefikOidc{ issuerURL: "https://auth.example.com", clientID: "service-a-client-id", clientSecret: "service-a-secret", audience: "service-a-client-id", // Service A uses its clientID as audience jwkCache: mockJWKCache, jwksURL: "https://auth.example.com/.well-known/jwks.json", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, httpClient: &http.Client{}, } serviceA.jwtVerifier = serviceA serviceA.tokenVerifier = serviceA t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) { // Create a token intended for service B serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://auth.example.com", "aud": "https://service-b.example.com", // For service B "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "attacker@example.com", "email": "attacker@example.com", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create service B token: %v", err) } // Try to verify the service B token on service A err = serviceA.VerifyToken(serviceBToken) if err == nil { t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A") } else if !strings.Contains(err.Error(), "invalid audience") { t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err) } else { t.Logf("Token confusion attack correctly prevented: %v", err) } }) } // TestAudienceSecurityWildcardInjection tests that wildcards are rejected func TestAudienceSecurityWildcardInjection(t *testing.T) { tests := []struct { name string audience string }{ { name: "Single asterisk", audience: "*", }, { name: "Wildcard in URL", audience: "https://*.example.com", }, { name: "Wildcard in path", audience: "https://api.example.com/*", }, { name: "Multiple wildcards", audience: "https://*.*.example.com", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := CreateConfig() config.ProviderURL = "https://provider.example.com" config.ClientID = "test-client-id" config.ClientSecret = "test-client-secret" config.CallbackURL = "/callback" config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength) config.Audience = tt.audience err := config.Validate() if err == nil { t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience) } else if !strings.Contains(err.Error(), "must not contain wildcards") { t.Errorf("Expected 'must not contain wildcards' error, got: %v", err) } }) } } // TestAudienceSecurityInjectionAttempts tests various injection attempts func TestAudienceSecurityInjectionAttempts(t *testing.T) { tests := []struct { name string audience string errContains string }{ { name: "Newline injection", audience: "api.example.com\nmalicious.com", errContains: "contains invalid characters", }, { name: "Carriage return injection", audience: "api.example.com\rmalicious.com", errContains: "contains invalid characters", }, { name: "Tab injection", audience: "api.example.com\tmalicious.com", errContains: "contains invalid characters", }, { name: "Null byte injection", audience: "api.example.com\x00malicious.com", errContains: "contains invalid characters", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := CreateConfig() config.ProviderURL = "https://provider.example.com" config.ClientID = "test-client-id" config.ClientSecret = "test-client-secret" config.CallbackURL = "/callback" config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength) config.Audience = tt.audience err := config.Validate() if err == nil { t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name) } else if !strings.Contains(err.Error(), tt.errContains) { t.Errorf("Expected error containing %q, got: %v", tt.errContains, err) } }) } } // TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences func TestAudienceWithReplayProtection(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Generate test keys rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey jwk := JWK{ Kty: "RSA", Kid: "test-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) tOidc := &TraefikOidc{ issuerURL: "https://auth.example.com", clientID: "test-client-id", clientSecret: "test-client-secret", jwkCache: mockJWKCache, jwksURL: "https://auth.example.com/.well-known/jwks.json", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, httpClient: &http.Client{}, } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc // Create a token with custom audience and fixed JTI fixedJTI := "replay-test-jti-" + generateRandomString(8) customAudience := "https://api.example.com" // Set the audience field to match what we expect tOidc.audience = customAudience jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://auth.example.com", "aud": customAudience, "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "test-user", "email": "user@example.com", "jti": fixedJTI, }) if err != nil { t.Fatalf("Failed to create JWT: %v", err) } // First verification should succeed err = tOidc.VerifyToken(jwt) if err != nil { t.Fatalf("First verification failed: %v", err) } // Verify that the JTI was blacklisted if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil { t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)") } else { t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI) } } // TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware func TestAudienceEndToEndScenario(t *testing.T) { // Create cleanup helper tc := newTestCleanup(t) // Create a test next handler nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("Authenticated with custom audience")) }) // Generate test keys rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("Failed to generate RSA key: %v", err) } rsaPublicKey := &rsaPrivateKey.PublicKey jwk := JWK{ Kty: "RSA", Kid: "test-key-id", Alg: "RS256", N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), } jwks := &JWKSet{ Keys: []JWK{jwk}, } mockJWKCache := &MockJWKCache{ JWKS: jwks, Err: nil, } logger := NewLogger("debug") sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger) if err != nil { t.Fatalf("Failed to create session manager: %v", err) } tokenBlacklist := tc.addCache(NewCache()) tokenCache := tc.addTokenCache(NewTokenCache()) customAudience := "https://api.company.com" tOidc := &TraefikOidc{ next: nextHandler, name: "test", redirURLPath: "/callback", logoutURLPath: "/callback/logout", issuerURL: "https://auth.company.com", clientID: "test-client-id", clientSecret: "test-client-secret", audience: customAudience, // Set custom audience jwkCache: mockJWKCache, jwksURL: "https://auth.company.com/.well-known/jwks.json", tokenBlacklist: tokenBlacklist, tokenCache: tokenCache, limiter: rate.NewLimiter(rate.Every(time.Second), 10), logger: logger, allowedUserDomains: map[string]struct{}{"company.com": {}}, excludedURLs: map[string]struct{}{}, httpClient: &http.Client{}, initComplete: make(chan struct{}), sessionManager: sm, extractClaimsFunc: extractClaims, } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc close(tOidc.initComplete) t.Run("End-to-end with correct custom audience", func(t *testing.T) { // Create a valid token with the custom audience validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://auth.company.com", "aud": customAudience, "exp": float64(time.Now().Add(1 * time.Hour).Unix()), "iat": float64(time.Now().Add(-2 * time.Minute).Unix()), "sub": "user-123", "email": "user@company.com", "jti": generateRandomString(16), }) if err != nil { t.Fatalf("Failed to create valid JWT: %v", err) } // Create a request with authenticated session req := httptest.NewRequest("GET", "/protected", nil) req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Host", "company.com") // Create session with token resp := httptest.NewRecorder() session, err := sm.GetSession(req) if err != nil { t.Fatalf("Failed to get session: %v", err) } session.SetAuthenticated(true) session.SetEmail("user@company.com") session.SetIDToken(validJWT) session.SetAccessToken(validJWT) if err := session.Save(req, resp); err != nil { t.Fatalf("Failed to save session: %v", err) } // Get cookies and add them to a new request cookies := resp.Result().Cookies() req = httptest.NewRequest("GET", "/protected", nil) req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Host", "company.com") for _, cookie := range cookies { req.AddCookie(cookie) } resp = httptest.NewRecorder() tOidc.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String()) } }) }