Files
traefikoidc/audience_validation_test.go
T
lukaszraczylo bde1db1c3b traefik plugin 0.7.7 (#73)
* Automatic discovery of the scopes.

Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication:  FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication:  SUCCEEDS

* Resolves issue #74 by enabling user to specify expected audience in the configuration.

* Fix flaky tests.
2025-10-08 11:44:00 +01:00

928 lines
27 KiB
Go

package traefikoidc
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
func TestConfigAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
wantErr bool
errContains string
}{
{
name: "Empty audience is valid for backward compatibility",
audience: "",
wantErr: false,
},
{
name: "Valid HTTPS URL audience Auth0 format",
audience: "https://api.example.com",
wantErr: false,
},
{
name: "Valid identifier audience",
audience: "my-api",
wantErr: false,
},
{
name: "Valid Azure AD Application ID URI format",
audience: "api://12345-guid-67890",
wantErr: false,
},
{
name: "Valid Auth0 API identifier",
audience: "https://my-company.auth0.com/api/v2/",
wantErr: false,
},
{
name: "HTTP URL audience should fail",
audience: "http://api.example.com",
wantErr: true,
errContains: "must use HTTPS",
},
{
name: "Audience with wildcard should fail",
audience: "https://api.*.example.com",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience with single asterisk should fail",
audience: "*",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience over 256 characters should fail",
audience: strings.Repeat("a", 257),
wantErr: true,
errContains: "must not exceed 256 characters",
},
{
name: "Audience with newline should fail",
audience: "my-api\ninjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with carriage return should fail",
audience: "my-api\rinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with tab should fail",
audience: "my-api\tinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Valid audience exactly 256 characters",
audience: strings.Repeat("a", 256),
wantErr: false,
},
{
name: "Valid simple identifier",
audience: "my-service-api",
wantErr: false,
},
{
name: "Valid URN format",
audience: "urn:myservice:api:v1",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceVerification tests JWT verification with custom audience values
func TestJWTAudienceVerification(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate RSA key for signing JWTs
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
// Create JWK
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tests := []struct {
name string
configAudience string
tokenAudience interface{}
wantErr bool
errContains string
skipReplayCheck bool
}{
{
name: "JWT with string aud matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://api.example.com",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with array aud containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with string aud NOT matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://wrong-api.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with array aud NOT containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with clientID as aud when no custom audience configured",
configAudience: "",
tokenAudience: "test-client-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with empty string aud",
configAudience: "https://api.example.com",
tokenAudience: "",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "Azure AD Application ID URI format",
configAudience: "api://12345-app-id",
tokenAudience: "api://12345-app-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Auth0 custom API audience",
configAudience: "https://mycompany.com/api",
tokenAudience: "https://mycompany.com/api",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Token confusion attack - audience for different service",
configAudience: "https://service-a.example.com",
tokenAudience: "https://service-b.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create TraefikOidc instance
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Set up the token verifier and JWT verifier
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Determine the expected audience for validation
expectedAudience := tt.configAudience
if expectedAudience == "" {
expectedAudience = tOidc.clientID
}
// Set the audience field on the tOidc instance
tOidc.audience = expectedAudience
// Create JWT with specified audience
jti := generateRandomString(16)
if tt.skipReplayCheck {
// Use a unique JTI for each test to avoid replay detection
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
}
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": tt.tokenAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Verify the token
err = tOidc.VerifyToken(jwt)
if (err != nil) != tt.wantErr {
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
// when the Audience field is not set
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Test with no custom audience configured - should use clientID
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id", // Should match clientID
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
err = ts.tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
}
}
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Auth0 scenario: custom audience for API access
config := CreateConfig()
config.ProviderURL = "https://mycompany.auth0.com"
config.ClientID = "auth0-client-id"
config.ClientSecret = "auth0-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "https://api.mycompany.com" // Custom API audience
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Auth0 config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "auth0-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches configured audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Auth0 token verification failed: %v", err)
}
})
t.Run("Auth0 token with clientID instead of API audience should fail", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.ClientID, // Using clientID instead of API audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err == nil {
t.Error("Auth0 token with wrong audience should have been rejected")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error, got: %v", err)
}
})
}
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Azure AD scenario: Application ID URI format
config := CreateConfig()
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
config.ClientID = "azure-client-id"
config.ClientSecret = "azure-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Azure AD config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "azure-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches Application ID URI
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token verification failed: %v", err)
}
})
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
}
})
}
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
// Service A configuration
serviceA := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "service-a-client-id",
clientSecret: "service-a-secret",
audience: "service-a-client-id", // Service A uses its clientID as audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
serviceA.jwtVerifier = serviceA
serviceA.tokenVerifier = serviceA
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
// Create a token intended for service B
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": "https://service-b.example.com", // For service B
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "attacker@example.com",
"email": "attacker@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create service B token: %v", err)
}
// Try to verify the service B token on service A
err = serviceA.VerifyToken(serviceBToken)
if err == nil {
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
} else {
t.Logf("Token confusion attack correctly prevented: %v", err)
}
})
}
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
func TestAudienceSecurityWildcardInjection(t *testing.T) {
tests := []struct {
name string
audience string
}{
{
name: "Single asterisk",
audience: "*",
},
{
name: "Wildcard in URL",
audience: "https://*.example.com",
},
{
name: "Wildcard in path",
audience: "https://api.example.com/*",
},
{
name: "Multiple wildcards",
audience: "https://*.*.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
}
})
}
}
// TestAudienceSecurityInjectionAttempts tests various injection attempts
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
tests := []struct {
name string
audience string
errContains string
}{
{
name: "Newline injection",
audience: "api.example.com\nmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Carriage return injection",
audience: "api.example.com\rmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Tab injection",
audience: "api.example.com\tmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Null byte injection",
audience: "api.example.com\x00malicious.com",
errContains: "contains invalid characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
}
})
}
}
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
func TestAudienceWithReplayProtection(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Create a token with custom audience and fixed JTI
fixedJTI := "replay-test-jti-" + generateRandomString(8)
customAudience := "https://api.example.com"
// Set the audience field to match what we expect
tOidc.audience = customAudience
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-user",
"email": "user@example.com",
"jti": fixedJTI,
})
if err != nil {
t.Fatalf("Failed to create JWT: %v", err)
}
// First verification should succeed
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Fatalf("First verification failed: %v", err)
}
// Verify that the JTI was blacklisted
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
} else {
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
}
}
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
func TestAudienceEndToEndScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Create a test next handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Authenticated with custom audience"))
})
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
customAudience := "https://api.company.com"
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://auth.company.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
audience: customAudience, // Set custom audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.company.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
allowedUserDomains: map[string]struct{}{"company.com": {}},
excludedURLs: map[string]struct{}{},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
close(tOidc.initComplete)
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
// Create a valid token with the custom audience
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.company.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "user-123",
"email": "user@company.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create valid JWT: %v", err)
}
// Create a request with authenticated session
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
// Create session with token
resp := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetAuthenticated(true)
session.SetEmail("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
if err := session.Save(req, resp); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies and add them to a new request
cookies := resp.Result().Cookies()
req = httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp = httptest.NewRecorder()
tOidc.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
}
})
}