mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
bde1db1c3b
* Automatic discovery of the scopes. Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider. This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes. Before: User configures: scopes: ["openid", "profile", "email", "offline_access"] Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed" Authentication: ❌ FAILS After: User configures: scopes: ["openid", "profile", "email", "offline_access"] Middleware checks discovery doc → offline_access not supported Automatically filters to: ["openid", "profile", "email"] Authentication: ✅ SUCCEEDS * Resolves issue #74 by enabling user to specify expected audience in the configuration. * Fix flaky tests.
928 lines
27 KiB
Go
928 lines
27 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
|
func TestConfigAudienceValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
wantErr bool
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "Empty audience is valid for backward compatibility",
|
|
audience: "",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid HTTPS URL audience Auth0 format",
|
|
audience: "https://api.example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid identifier audience",
|
|
audience: "my-api",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid Azure AD Application ID URI format",
|
|
audience: "api://12345-guid-67890",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid Auth0 API identifier",
|
|
audience: "https://my-company.auth0.com/api/v2/",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "HTTP URL audience should fail",
|
|
audience: "http://api.example.com",
|
|
wantErr: true,
|
|
errContains: "must use HTTPS",
|
|
},
|
|
{
|
|
name: "Audience with wildcard should fail",
|
|
audience: "https://api.*.example.com",
|
|
wantErr: true,
|
|
errContains: "must not contain wildcards",
|
|
},
|
|
{
|
|
name: "Audience with single asterisk should fail",
|
|
audience: "*",
|
|
wantErr: true,
|
|
errContains: "must not contain wildcards",
|
|
},
|
|
{
|
|
name: "Audience over 256 characters should fail",
|
|
audience: strings.Repeat("a", 257),
|
|
wantErr: true,
|
|
errContains: "must not exceed 256 characters",
|
|
},
|
|
{
|
|
name: "Audience with newline should fail",
|
|
audience: "my-api\ninjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Audience with carriage return should fail",
|
|
audience: "my-api\rinjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Audience with tab should fail",
|
|
audience: "my-api\tinjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Valid audience exactly 256 characters",
|
|
audience: strings.Repeat("a", 256),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid simple identifier",
|
|
audience: "my-service-api",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid URN format",
|
|
audience: "urn:myservice:api:v1",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
|
func TestJWTAudienceVerification(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate RSA key for signing JWTs
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tests := []struct {
|
|
name string
|
|
configAudience string
|
|
tokenAudience interface{}
|
|
wantErr bool
|
|
errContains string
|
|
skipReplayCheck bool
|
|
}{
|
|
{
|
|
name: "JWT with string aud matching configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "https://api.example.com",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with array aud containing configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with string aud NOT matching configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "https://wrong-api.example.com",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with array aud NOT containing configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with clientID as aud when no custom audience configured",
|
|
configAudience: "",
|
|
tokenAudience: "test-client-id",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with empty string aud",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Azure AD Application ID URI format",
|
|
configAudience: "api://12345-app-id",
|
|
tokenAudience: "api://12345-app-id",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Auth0 custom API audience",
|
|
configAudience: "https://mycompany.com/api",
|
|
tokenAudience: "https://mycompany.com/api",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Token confusion attack - audience for different service",
|
|
configAudience: "https://service-a.example.com",
|
|
tokenAudience: "https://service-b.example.com",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Determine the expected audience for validation
|
|
expectedAudience := tt.configAudience
|
|
if expectedAudience == "" {
|
|
expectedAudience = tOidc.clientID
|
|
}
|
|
|
|
// Set the audience field on the tOidc instance
|
|
tOidc.audience = expectedAudience
|
|
|
|
// Create JWT with specified audience
|
|
jti := generateRandomString(16)
|
|
if tt.skipReplayCheck {
|
|
// Use a unique JTI for each test to avoid replay detection
|
|
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
|
}
|
|
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": tt.tokenAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = tOidc.VerifyToken(jwt)
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
|
// when the Audience field is not set
|
|
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test with no custom audience configured - should use clientID
|
|
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id", // Should match clientID
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
err = ts.tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
|
}
|
|
}
|
|
|
|
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
|
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Simulate Auth0 scenario: custom audience for API access
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://mycompany.auth0.com"
|
|
config.ClientID = "auth0-client-id"
|
|
config.ClientSecret = "auth0-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = "https://api.mycompany.com" // Custom API audience
|
|
|
|
// Validate config
|
|
if err := config.Validate(); err != nil {
|
|
t.Fatalf("Auth0 config validation failed: %v", err)
|
|
}
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "auth0-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: config.ProviderURL,
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
audience: config.Audience, // Set audience from config
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Default audience to clientID if not specified
|
|
if tOidc.audience == "" {
|
|
tOidc.audience = tOidc.clientID
|
|
}
|
|
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.Audience, // Matches configured audience
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "auth0|123456",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Auth0 token verification failed: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Auth0 token with clientID instead of API audience should fail", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.ClientID, // Using clientID instead of API audience
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "auth0|123456",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err == nil {
|
|
t.Error("Auth0 token with wrong audience should have been rejected")
|
|
} else if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
|
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Simulate Azure AD scenario: Application ID URI format
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
|
config.ClientID = "azure-client-id"
|
|
config.ClientSecret = "azure-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
|
|
|
// Validate config
|
|
if err := config.Validate(); err != nil {
|
|
t.Fatalf("Azure AD config validation failed: %v", err)
|
|
}
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "azure-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: config.ProviderURL,
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
audience: config.Audience, // Set audience from config
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Default audience to clientID if not specified
|
|
if tOidc.audience == "" {
|
|
tOidc.audience = tOidc.clientID
|
|
}
|
|
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.Audience, // Matches Application ID URI
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "azure-user-id",
|
|
"email": "user@example.com",
|
|
"oid": "object-id-12345",
|
|
"tid": "tenant-id",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Azure AD token verification failed: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "azure-user-id",
|
|
"email": "user@example.com",
|
|
"oid": "object-id-12345",
|
|
"tid": "tenant-id",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
|
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Service A configuration
|
|
serviceA := &TraefikOidc{
|
|
issuerURL: "https://auth.example.com",
|
|
clientID: "service-a-client-id",
|
|
clientSecret: "service-a-secret",
|
|
audience: "service-a-client-id", // Service A uses its clientID as audience
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
serviceA.jwtVerifier = serviceA
|
|
serviceA.tokenVerifier = serviceA
|
|
|
|
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
|
// Create a token intended for service B
|
|
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.example.com",
|
|
"aud": "https://service-b.example.com", // For service B
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "attacker@example.com",
|
|
"email": "attacker@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create service B token: %v", err)
|
|
}
|
|
|
|
// Try to verify the service B token on service A
|
|
err = serviceA.VerifyToken(serviceBToken)
|
|
if err == nil {
|
|
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
|
} else if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
|
} else {
|
|
t.Logf("Token confusion attack correctly prevented: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
|
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
}{
|
|
{
|
|
name: "Single asterisk",
|
|
audience: "*",
|
|
},
|
|
{
|
|
name: "Wildcard in URL",
|
|
audience: "https://*.example.com",
|
|
},
|
|
{
|
|
name: "Wildcard in path",
|
|
audience: "https://api.example.com/*",
|
|
},
|
|
{
|
|
name: "Multiple wildcards",
|
|
audience: "https://*.*.example.com",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if err == nil {
|
|
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
|
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
|
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
|
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "Newline injection",
|
|
audience: "api.example.com\nmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Carriage return injection",
|
|
audience: "api.example.com\rmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Tab injection",
|
|
audience: "api.example.com\tmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Null byte injection",
|
|
audience: "api.example.com\x00malicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if err == nil {
|
|
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
|
} else if !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
|
func TestAudienceWithReplayProtection(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://auth.example.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Create a token with custom audience and fixed JTI
|
|
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
|
customAudience := "https://api.example.com"
|
|
|
|
// Set the audience field to match what we expect
|
|
tOidc.audience = customAudience
|
|
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.example.com",
|
|
"aud": customAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-user",
|
|
"email": "user@example.com",
|
|
"jti": fixedJTI,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT: %v", err)
|
|
}
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Fatalf("First verification failed: %v", err)
|
|
}
|
|
|
|
// Verify that the JTI was blacklisted
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
|
} else {
|
|
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
|
}
|
|
}
|
|
|
|
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
|
func TestAudienceEndToEndScenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a test next handler
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("Authenticated with custom audience"))
|
|
})
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
customAudience := "https://api.company.com"
|
|
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://auth.company.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
audience: customAudience, // Set custom audience
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
|
excludedURLs: map[string]struct{}{},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: sm,
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
close(tOidc.initComplete)
|
|
|
|
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
|
// Create a valid token with the custom audience
|
|
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.company.com",
|
|
"aud": customAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "user-123",
|
|
"email": "user@company.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Create a request with authenticated session
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "company.com")
|
|
|
|
// Create session with token
|
|
resp := httptest.NewRecorder()
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
session.SetAuthenticated(true)
|
|
session.SetEmail("user@company.com")
|
|
session.SetIDToken(validJWT)
|
|
session.SetAccessToken(validJWT)
|
|
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies and add them to a new request
|
|
cookies := resp.Result().Cookies()
|
|
req = httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "company.com")
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
resp = httptest.NewRecorder()
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
|
}
|
|
})
|
|
}
|