mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
e64fc7f730
* Add redis support for distributed caching * Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * ... and another all nighter. * fixup! ... and another all nighter. * fixup! fixup! ... and another all nighter. * fixup! fixup! fixup! ... and another all nighter. * Resolve issue #85 by adding ability to set custom claims in JWT tokens * Remove redundant validation in auth middleware ( issue #89 ) * Add ability to set cookie prefix for session cookies ( #87 ) * fixup! Add ability to set cookie prefix for session cookies ( #87 ) * Add ability to set cookie max age - issue #91 * Potential fix for code scanning alert no. 10: Size computation for allocation may overflow Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fixup! Merge main into 0.8.0-redis: resolve conflicts --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
932 lines
28 KiB
Go
932 lines
28 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
|
func TestConfigAudienceValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
wantErr bool
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "Empty audience is valid for backward compatibility",
|
|
audience: "",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid HTTPS URL audience Auth0 format",
|
|
audience: "https://api.example.com",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid identifier audience",
|
|
audience: "my-api",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid Azure AD Application ID URI format",
|
|
audience: "api://12345-guid-67890",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid Auth0 API identifier",
|
|
audience: "https://my-company.auth0.com/api/v2/",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "HTTP URL audience should fail",
|
|
audience: "http://api.example.com",
|
|
wantErr: true,
|
|
errContains: "must use HTTPS",
|
|
},
|
|
{
|
|
name: "Audience with wildcard should fail",
|
|
audience: "https://api.*.example.com",
|
|
wantErr: true,
|
|
errContains: "must not contain wildcards",
|
|
},
|
|
{
|
|
name: "Audience with single asterisk should fail",
|
|
audience: "*",
|
|
wantErr: true,
|
|
errContains: "must not contain wildcards",
|
|
},
|
|
{
|
|
name: "Audience over 256 characters should fail",
|
|
audience: strings.Repeat("a", 257),
|
|
wantErr: true,
|
|
errContains: "must not exceed 256 characters",
|
|
},
|
|
{
|
|
name: "Audience with newline should fail",
|
|
audience: "my-api\ninjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Audience with carriage return should fail",
|
|
audience: "my-api\rinjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Audience with tab should fail",
|
|
audience: "my-api\tinjection",
|
|
wantErr: true,
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Valid audience exactly 256 characters",
|
|
audience: strings.Repeat("a", 256),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid simple identifier",
|
|
audience: "my-service-api",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "Valid URN format",
|
|
audience: "urn:myservice:api:v1",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
|
func TestJWTAudienceVerification(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate RSA key for signing JWTs
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
// Create JWK
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tests := []struct {
|
|
name string
|
|
configAudience string
|
|
tokenAudience interface{}
|
|
wantErr bool
|
|
errContains string
|
|
skipReplayCheck bool
|
|
}{
|
|
{
|
|
name: "JWT with string aud matching configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "https://api.example.com",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with array aud containing configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with string aud NOT matching configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "https://wrong-api.example.com",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with array aud NOT containing configured audience",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with clientID as aud when no custom audience configured",
|
|
configAudience: "",
|
|
tokenAudience: "test-client-id",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "JWT with empty string aud",
|
|
configAudience: "https://api.example.com",
|
|
tokenAudience: "",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Azure AD Application ID URI format",
|
|
configAudience: "api://12345-app-id",
|
|
tokenAudience: "api://12345-app-id",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Auth0 custom API audience",
|
|
configAudience: "https://mycompany.com/api",
|
|
tokenAudience: "https://mycompany.com/api",
|
|
wantErr: false,
|
|
skipReplayCheck: true,
|
|
},
|
|
{
|
|
name: "Token confusion attack - audience for different service",
|
|
configAudience: "https://service-a.example.com",
|
|
tokenAudience: "https://service-b.example.com",
|
|
wantErr: true,
|
|
errContains: "invalid audience",
|
|
skipReplayCheck: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create TraefikOidc instance
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Set up the token verifier and JWT verifier
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Determine the expected audience for validation
|
|
expectedAudience := tt.configAudience
|
|
if expectedAudience == "" {
|
|
expectedAudience = tOidc.clientID
|
|
}
|
|
|
|
// Set the audience field on the tOidc instance
|
|
tOidc.audience = expectedAudience
|
|
|
|
// Create JWT with specified audience
|
|
jti := generateRandomString(16)
|
|
if tt.skipReplayCheck {
|
|
// Use a unique JTI for each test to avoid replay detection
|
|
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
|
}
|
|
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": tt.tokenAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Verify the token
|
|
err = tOidc.VerifyToken(jwt)
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
|
// when the Audience field is not set
|
|
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Test with no custom audience configured - should use clientID
|
|
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id", // Should match clientID
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
err = ts.tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
|
}
|
|
}
|
|
|
|
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
|
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Simulate Auth0 scenario: custom audience for API access
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://mycompany.auth0.com"
|
|
config.ClientID = "auth0-client-id"
|
|
config.ClientSecret = "auth0-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = "https://api.mycompany.com" // Custom API audience
|
|
|
|
// Validate config
|
|
if err := config.Validate(); err != nil {
|
|
t.Fatalf("Auth0 config validation failed: %v", err)
|
|
}
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "auth0-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: config.ProviderURL,
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
audience: config.Audience, // Set audience from config
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Default audience to clientID if not specified
|
|
if tOidc.audience == "" {
|
|
tOidc.audience = tOidc.clientID
|
|
}
|
|
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.Audience, // Matches configured audience
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "auth0|123456",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Auth0 token verification failed: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.ClientID, // Using clientID instead of API audience
|
|
"scope": "openid profile email", // Mark as access token
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "auth0|123456",
|
|
"email": "user@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err == nil {
|
|
t.Error("Auth0 access token with wrong audience should have been rejected")
|
|
} else if !strings.Contains(err.Error(), "invalid audience") {
|
|
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
|
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Simulate Azure AD scenario: Application ID URI format
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
|
config.ClientID = "azure-client-id"
|
|
config.ClientSecret = "azure-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
|
|
|
// Validate config
|
|
if err := config.Validate(); err != nil {
|
|
t.Fatalf("Azure AD config validation failed: %v", err)
|
|
}
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "azure-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: config.ProviderURL,
|
|
clientID: config.ClientID,
|
|
clientSecret: config.ClientSecret,
|
|
audience: config.Audience, // Set audience from config
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
|
|
// Default audience to clientID if not specified
|
|
if tOidc.audience == "" {
|
|
tOidc.audience = tOidc.clientID
|
|
}
|
|
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": config.Audience, // Matches Application ID URI
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "azure-user-id",
|
|
"email": "user@example.com",
|
|
"oid": "object-id-12345",
|
|
"tid": "tenant-id",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Azure AD token verification failed: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
|
"iss": config.ProviderURL,
|
|
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "azure-user-id",
|
|
"email": "user@example.com",
|
|
"oid": "object-id-12345",
|
|
"tid": "tenant-id",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
|
}
|
|
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
|
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
// Service A configuration
|
|
serviceA := &TraefikOidc{
|
|
issuerURL: "https://auth.example.com",
|
|
clientID: "service-a-client-id",
|
|
clientSecret: "service-a-secret",
|
|
audience: "service-a-client-id", // Service A uses its clientID as audience
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
serviceA.jwtVerifier = serviceA
|
|
serviceA.tokenVerifier = serviceA
|
|
|
|
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
|
// Create a token intended for service B
|
|
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.example.com",
|
|
"aud": "https://service-b.example.com", // For service B
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "attacker@example.com",
|
|
"email": "attacker@example.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create service B token: %v", err)
|
|
}
|
|
|
|
// Try to verify the service B token on service A
|
|
err = serviceA.VerifyToken(serviceBToken)
|
|
switch {
|
|
case err == nil:
|
|
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
|
case !strings.Contains(err.Error(), "invalid audience"):
|
|
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
|
default:
|
|
t.Logf("Token confusion attack correctly prevented: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
|
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
}{
|
|
{
|
|
name: "Single asterisk",
|
|
audience: "*",
|
|
},
|
|
{
|
|
name: "Wildcard in URL",
|
|
audience: "https://*.example.com",
|
|
},
|
|
{
|
|
name: "Wildcard in path",
|
|
audience: "https://api.example.com/*",
|
|
},
|
|
{
|
|
name: "Multiple wildcards",
|
|
audience: "https://*.*.example.com",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if err == nil {
|
|
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
|
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
|
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
|
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
audience string
|
|
errContains string
|
|
}{
|
|
{
|
|
name: "Newline injection",
|
|
audience: "api.example.com\nmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Carriage return injection",
|
|
audience: "api.example.com\rmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Tab injection",
|
|
audience: "api.example.com\tmalicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
{
|
|
name: "Null byte injection",
|
|
audience: "api.example.com\x00malicious.com",
|
|
errContains: "contains invalid characters",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
config := CreateConfig()
|
|
config.ProviderURL = "https://provider.example.com"
|
|
config.ClientID = "test-client-id"
|
|
config.ClientSecret = "test-client-secret"
|
|
config.CallbackURL = "/callback"
|
|
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
|
config.Audience = tt.audience
|
|
|
|
err := config.Validate()
|
|
if err == nil {
|
|
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
|
} else if !strings.Contains(err.Error(), tt.errContains) {
|
|
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
|
func TestAudienceWithReplayProtection(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
tOidc := &TraefikOidc{
|
|
issuerURL: "https://auth.example.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
httpClient: &http.Client{},
|
|
}
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
|
|
// Create a token with custom audience and fixed JTI
|
|
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
|
customAudience := "https://api.example.com"
|
|
|
|
// Set the audience field to match what we expect
|
|
tOidc.audience = customAudience
|
|
|
|
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.example.com",
|
|
"aud": customAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "test-user",
|
|
"email": "user@example.com",
|
|
"jti": fixedJTI,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create JWT: %v", err)
|
|
}
|
|
|
|
// First verification should succeed
|
|
err = tOidc.VerifyToken(jwt)
|
|
if err != nil {
|
|
t.Fatalf("First verification failed: %v", err)
|
|
}
|
|
|
|
// Verify that the JTI was blacklisted
|
|
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
|
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
|
} else {
|
|
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
|
}
|
|
}
|
|
|
|
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
|
func TestAudienceEndToEndScenario(t *testing.T) {
|
|
// Create cleanup helper
|
|
tc := newTestCleanup(t)
|
|
|
|
// Create a test next handler
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("Authenticated with custom audience"))
|
|
})
|
|
|
|
// Generate test keys
|
|
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
rsaPublicKey := &rsaPrivateKey.PublicKey
|
|
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
mockJWKCache := &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
logger := NewLogger("debug")
|
|
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session manager: %v", err)
|
|
}
|
|
|
|
tokenBlacklist := tc.addCache(NewCache())
|
|
tokenCache := tc.addTokenCache(NewTokenCache())
|
|
|
|
customAudience := "https://api.company.com"
|
|
|
|
tOidc := &TraefikOidc{
|
|
next: nextHandler,
|
|
name: "test",
|
|
redirURLPath: "/callback",
|
|
logoutURLPath: "/callback/logout",
|
|
issuerURL: "https://auth.company.com",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
audience: customAudience, // Set custom audience
|
|
jwkCache: mockJWKCache,
|
|
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
|
excludedURLs: map[string]struct{}{},
|
|
httpClient: &http.Client{},
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: sm,
|
|
extractClaimsFunc: extractClaims,
|
|
}
|
|
tOidc.jwtVerifier = tOidc
|
|
tOidc.tokenVerifier = tOidc
|
|
close(tOidc.initComplete)
|
|
|
|
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
|
// Create a valid token with the custom audience
|
|
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://auth.company.com",
|
|
"aud": customAudience,
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
|
"sub": "user-123",
|
|
"email": "user@company.com",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create valid JWT: %v", err)
|
|
}
|
|
|
|
// Create a request with authenticated session
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "company.com")
|
|
|
|
// Create session with token
|
|
resp := httptest.NewRecorder()
|
|
session, err := sm.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
if err := session.SetAuthenticated(true); err != nil {
|
|
t.Fatalf("Failed to set authenticated: %v", err)
|
|
}
|
|
session.SetEmail("user@company.com")
|
|
session.SetIDToken(validJWT)
|
|
session.SetAccessToken(validJWT)
|
|
|
|
if err := session.Save(req, resp); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Get cookies and add them to a new request
|
|
cookies := resp.Result().Cookies()
|
|
req = httptest.NewRequest("GET", "/protected", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Host", "company.com")
|
|
for _, cookie := range cookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
resp = httptest.NewRecorder()
|
|
tOidc.ServeHTTP(resp, req)
|
|
|
|
if resp.Code != http.StatusOK {
|
|
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
|
}
|
|
})
|
|
}
|