traefik plugin 0.7.7 (#73)

* Automatic discovery of the scopes.

Issue #61 raised very valid concerns about users configuring scopes that are not supported by the provider.
This change introduces automatic discovery of supported scopes by fetching the provider's discovery document and filtering out unsupported scopes.

Before:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Self-hosted GitLab: "The requested scope is invalid, unknown, or malformed"
Authentication:  FAILS

After:
User configures: scopes: ["openid", "profile", "email", "offline_access"]
Middleware checks discovery doc → offline_access not supported
Automatically filters to: ["openid", "profile", "email"]
Authentication:  SUCCEEDS

* Resolves issue #74 by enabling user to specify expected audience in the configuration.

* Fix flaky tests.
This commit is contained in:
2025-10-08 11:44:00 +01:00
committed by GitHub
parent 79d34ea4c9
commit bde1db1c3b
29 changed files with 3214 additions and 85 deletions
+1
View File
@@ -8,6 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
- **Domain restrictions**: Limit access to specific email domains or individual users
- **Role-based access control**: Restrict access based on roles and groups from OIDC claims
+143
View File
@@ -0,0 +1,143 @@
package traefikoidc
import (
"context"
"net/http"
"strings"
"testing"
)
// TestAudienceConfiguration tests the custom audience configuration feature
func TestAudienceConfiguration(t *testing.T) {
tests := []struct {
name string
configAudience string
clientID string
expectedAudience string
}{
{
name: "no custom audience - uses clientID",
configAudience: "",
clientID: "test-client-id",
expectedAudience: "test-client-id",
},
{
name: "custom audience specified",
configAudience: "api://custom-audience",
clientID: "test-client-id",
expectedAudience: "api://custom-audience",
},
{
name: "auth0 style custom audience",
configAudience: "https://api.example.com",
clientID: "test-client-id",
expectedAudience: "https://api.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create config with custom audience
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = tt.clientID
config.ClientSecret = "test-secret"
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
config.CallbackURL = "/callback"
config.Audience = tt.configAudience
// Create middleware instance
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
traefikOidc, err := NewWithContext(context.Background(), config, next, "test")
if err != nil {
t.Fatalf("Failed to create middleware: %v", err)
}
// Verify audience is set correctly
if traefikOidc.audience != tt.expectedAudience {
t.Errorf("Expected audience %s, got %s", tt.expectedAudience, traefikOidc.audience)
}
// Cleanup
traefikOidc.Close()
})
}
}
// TestAudienceValidation tests the audience validation in Config.Validate()
func TestAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
expectError bool
errorContains string
}{
{
name: "valid custom audience URL",
audience: "https://api.example.com",
expectError: false,
},
{
name: "valid azure style audience",
audience: "api://12345678-1234-1234-1234-123456789012",
expectError: false,
},
{
name: "empty audience is valid (uses clientID)",
audience: "",
expectError: false,
},
{
name: "http URL not allowed",
audience: "http://api.example.com",
expectError: true,
errorContains: "audience URL must use HTTPS",
},
{
name: "wildcard not allowed",
audience: "https://*.example.com",
expectError: true,
errorContains: "audience must not contain wildcards",
},
{
name: "too long audience",
audience: "https://" + string(make([]byte, 250)) + ".com",
expectError: true,
errorContains: "audience must not exceed 256 characters",
},
{
name: "invalid characters",
audience: "api://test\ninjection",
expectError: true,
errorContains: "audience contains invalid characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client"
config.ClientSecret = "test-secret"
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
config.CallbackURL = "/callback"
config.Audience = tt.audience
err := config.Validate()
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got none")
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("Expected error containing '%s', got: %v", tt.errorContains, err)
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
})
}
}
+927
View File
@@ -0,0 +1,927 @@
package traefikoidc
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/time/rate"
)
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
func TestConfigAudienceValidation(t *testing.T) {
tests := []struct {
name string
audience string
wantErr bool
errContains string
}{
{
name: "Empty audience is valid for backward compatibility",
audience: "",
wantErr: false,
},
{
name: "Valid HTTPS URL audience Auth0 format",
audience: "https://api.example.com",
wantErr: false,
},
{
name: "Valid identifier audience",
audience: "my-api",
wantErr: false,
},
{
name: "Valid Azure AD Application ID URI format",
audience: "api://12345-guid-67890",
wantErr: false,
},
{
name: "Valid Auth0 API identifier",
audience: "https://my-company.auth0.com/api/v2/",
wantErr: false,
},
{
name: "HTTP URL audience should fail",
audience: "http://api.example.com",
wantErr: true,
errContains: "must use HTTPS",
},
{
name: "Audience with wildcard should fail",
audience: "https://api.*.example.com",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience with single asterisk should fail",
audience: "*",
wantErr: true,
errContains: "must not contain wildcards",
},
{
name: "Audience over 256 characters should fail",
audience: strings.Repeat("a", 257),
wantErr: true,
errContains: "must not exceed 256 characters",
},
{
name: "Audience with newline should fail",
audience: "my-api\ninjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with carriage return should fail",
audience: "my-api\rinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Audience with tab should fail",
audience: "my-api\tinjection",
wantErr: true,
errContains: "contains invalid characters",
},
{
name: "Valid audience exactly 256 characters",
audience: strings.Repeat("a", 256),
wantErr: false,
},
{
name: "Valid simple identifier",
audience: "my-service-api",
wantErr: false,
},
{
name: "Valid URN format",
audience: "urn:myservice:api:v1",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceVerification tests JWT verification with custom audience values
func TestJWTAudienceVerification(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate RSA key for signing JWTs
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
// Create JWK
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tests := []struct {
name string
configAudience string
tokenAudience interface{}
wantErr bool
errContains string
skipReplayCheck bool
}{
{
name: "JWT with string aud matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://api.example.com",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with array aud containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with string aud NOT matching configured audience",
configAudience: "https://api.example.com",
tokenAudience: "https://wrong-api.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with array aud NOT containing configured audience",
configAudience: "https://api.example.com",
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "JWT with clientID as aud when no custom audience configured",
configAudience: "",
tokenAudience: "test-client-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "JWT with empty string aud",
configAudience: "https://api.example.com",
tokenAudience: "",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
{
name: "Azure AD Application ID URI format",
configAudience: "api://12345-app-id",
tokenAudience: "api://12345-app-id",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Auth0 custom API audience",
configAudience: "https://mycompany.com/api",
tokenAudience: "https://mycompany.com/api",
wantErr: false,
skipReplayCheck: true,
},
{
name: "Token confusion attack - audience for different service",
configAudience: "https://service-a.example.com",
tokenAudience: "https://service-b.example.com",
wantErr: true,
errContains: "invalid audience",
skipReplayCheck: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create TraefikOidc instance
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Set up the token verifier and JWT verifier
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Determine the expected audience for validation
expectedAudience := tt.configAudience
if expectedAudience == "" {
expectedAudience = tOidc.clientID
}
// Set the audience field on the tOidc instance
tOidc.audience = expectedAudience
// Create JWT with specified audience
jti := generateRandomString(16)
if tt.skipReplayCheck {
// Use a unique JTI for each test to avoid replay detection
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
}
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": tt.tokenAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": jti,
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
// Verify the token
err = tOidc.VerifyToken(jwt)
if (err != nil) != tt.wantErr {
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
}
})
}
}
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
// when the Audience field is not set
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
// Test with no custom audience configured - should use clientID
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id", // Should match clientID
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-subject",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create test JWT: %v", err)
}
err = ts.tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
}
}
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Auth0 scenario: custom audience for API access
config := CreateConfig()
config.ProviderURL = "https://mycompany.auth0.com"
config.ClientID = "auth0-client-id"
config.ClientSecret = "auth0-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "https://api.mycompany.com" // Custom API audience
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Auth0 config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "auth0-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches configured audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Auth0 token verification failed: %v", err)
}
})
t.Run("Auth0 token with clientID instead of API audience should fail", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.ClientID, // Using clientID instead of API audience
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "auth0|123456",
"email": "user@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Auth0 JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err == nil {
t.Error("Auth0 token with wrong audience should have been rejected")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error, got: %v", err)
}
})
}
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Simulate Azure AD scenario: Application ID URI format
config := CreateConfig()
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
config.ClientID = "azure-client-id"
config.ClientSecret = "azure-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
// Validate config
if err := config.Validate(); err != nil {
t.Fatalf("Azure AD config validation failed: %v", err)
}
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "azure-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: config.Audience, // Set audience from config
jwkCache: mockJWKCache,
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
// Default audience to clientID if not specified
if tOidc.audience == "" {
tOidc.audience = tOidc.clientID
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": config.Audience, // Matches Application ID URI
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token verification failed: %v", err)
}
})
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
"iss": config.ProviderURL,
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "azure-user-id",
"email": "user@example.com",
"oid": "object-id-12345",
"tid": "tenant-id",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create Azure AD JWT: %v", err)
}
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
}
})
}
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
// Service A configuration
serviceA := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "service-a-client-id",
clientSecret: "service-a-secret",
audience: "service-a-client-id", // Service A uses its clientID as audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
serviceA.jwtVerifier = serviceA
serviceA.tokenVerifier = serviceA
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
// Create a token intended for service B
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": "https://service-b.example.com", // For service B
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "attacker@example.com",
"email": "attacker@example.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create service B token: %v", err)
}
// Try to verify the service B token on service A
err = serviceA.VerifyToken(serviceBToken)
if err == nil {
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
} else if !strings.Contains(err.Error(), "invalid audience") {
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
} else {
t.Logf("Token confusion attack correctly prevented: %v", err)
}
})
}
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
func TestAudienceSecurityWildcardInjection(t *testing.T) {
tests := []struct {
name string
audience string
}{
{
name: "Single asterisk",
audience: "*",
},
{
name: "Wildcard in URL",
audience: "https://*.example.com",
},
{
name: "Wildcard in path",
audience: "https://api.example.com/*",
},
{
name: "Multiple wildcards",
audience: "https://*.*.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
}
})
}
}
// TestAudienceSecurityInjectionAttempts tests various injection attempts
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
tests := []struct {
name string
audience string
errContains string
}{
{
name: "Newline injection",
audience: "api.example.com\nmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Carriage return injection",
audience: "api.example.com\rmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Tab injection",
audience: "api.example.com\tmalicious.com",
errContains: "contains invalid characters",
},
{
name: "Null byte injection",
audience: "api.example.com\x00malicious.com",
errContains: "contains invalid characters",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := CreateConfig()
config.ProviderURL = "https://provider.example.com"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
config.Audience = tt.audience
err := config.Validate()
if err == nil {
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
} else if !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
}
})
}
}
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
func TestAudienceWithReplayProtection(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
tOidc := &TraefikOidc{
issuerURL: "https://auth.example.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://auth.example.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
httpClient: &http.Client{},
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
// Create a token with custom audience and fixed JTI
fixedJTI := "replay-test-jti-" + generateRandomString(8)
customAudience := "https://api.example.com"
// Set the audience field to match what we expect
tOidc.audience = customAudience
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.example.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "test-user",
"email": "user@example.com",
"jti": fixedJTI,
})
if err != nil {
t.Fatalf("Failed to create JWT: %v", err)
}
// First verification should succeed
err = tOidc.VerifyToken(jwt)
if err != nil {
t.Fatalf("First verification failed: %v", err)
}
// Verify that the JTI was blacklisted
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
} else {
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
}
}
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
func TestAudienceEndToEndScenario(t *testing.T) {
// Create cleanup helper
tc := newTestCleanup(t)
// Create a test next handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Authenticated with custom audience"))
})
// Generate test keys
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
rsaPublicKey := &rsaPrivateKey.PublicKey
jwk := JWK{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
}
jwks := &JWKSet{
Keys: []JWK{jwk},
}
mockJWKCache := &MockJWKCache{
JWKS: jwks,
Err: nil,
}
logger := NewLogger("debug")
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
tokenBlacklist := tc.addCache(NewCache())
tokenCache := tc.addTokenCache(NewTokenCache())
customAudience := "https://api.company.com"
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://auth.company.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
audience: customAudience, // Set custom audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.company.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
allowedUserDomains: map[string]struct{}{"company.com": {}},
excludedURLs: map[string]struct{}{},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
close(tOidc.initComplete)
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
// Create a valid token with the custom audience
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://auth.company.com",
"aud": customAudience,
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
"sub": "user-123",
"email": "user@company.com",
"jti": generateRandomString(16),
})
if err != nil {
t.Fatalf("Failed to create valid JWT: %v", err)
}
// Create a request with authenticated session
req := httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
// Create session with token
resp := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetAuthenticated(true)
session.SetEmail("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
if err := session.Save(req, resp); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Get cookies and add them to a new request
cookies := resp.Result().Cookies()
req = httptest.NewRequest("GET", "/protected", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "company.com")
for _, cookie := range cookies {
req.AddCookie(cookie)
}
resp = httptest.NewRecorder()
tOidc.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
}
})
}
+54 -23
View File
@@ -11,17 +11,24 @@ import (
"github.com/google/uuid"
)
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
type ScopeFilter interface {
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
}
// AuthHandler provides core authentication functionality for OIDC flows
type AuthHandler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
scopeFilter ScopeFilter // NEW
scopesSupported []string // NEW - from provider metadata
}
// Logger interface for dependency injection
@@ -32,17 +39,20 @@ type Logger interface {
// NewAuthHandler creates a new AuthHandler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
scopeFilter ScopeFilter, scopesSupported []string) *AuthHandler {
return &AuthHandler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
scopeFilter: scopeFilter, // NEW
scopesSupported: scopesSupported, // NEW
}
}
@@ -144,10 +154,25 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
scopes := make([]string, len(h.scopes))
copy(scopes, h.scopes)
if h.isGoogleProv() {
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
// Apply discovery-based scope filtering if available
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
}
// Then apply provider-specific modifications
if h.isGoogleProv() {
// Google: Remove offline_access if present, add access_type=offline
filteredScopes := make([]string, 0, len(scopes))
for _, scope := range scopes {
if scope != "offline_access" {
filteredScopes = append(filteredScopes, scope)
}
}
scopes = filteredScopes
params.Set("access_type", "offline")
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
params.Set("prompt", "consent")
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
} else if h.isAzureProv() {
@@ -155,7 +180,6 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
hasOfflineAccess = true
@@ -172,6 +196,7 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
}
} else {
// Standard providers: Add offline_access if not overriding and not present
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
hasOfflineAccess := false
for _, scope := range scopes {
@@ -189,6 +214,12 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
}
}
// Final filtering pass to remove anything the provider doesn't support
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
+581 -12
View File
@@ -22,6 +22,28 @@ func (l *mockLogger) Errorf(format string, args ...interface{}) {
l.errorMessages = append(l.errorMessages, format)
}
// mockScopeFilter is a mock implementation of the ScopeFilter interface for testing
type mockScopeFilter struct{}
func (m *mockScopeFilter) FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string {
// For testing, just return requested scopes if no supported scopes provided
if len(supportedScopes) == 0 {
return requestedScopes
}
// Simple filter logic for tests
filtered := make([]string, 0, len(requestedScopes))
supportedMap := make(map[string]bool)
for _, s := range supportedScopes {
supportedMap[s] = true
}
for _, s := range requestedScopes {
if supportedMap[s] {
filtered = append(filtered, s)
}
}
return filtered
}
type mockSessionData struct {
authenticated bool
email string
@@ -64,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
"test-client-id", "https://example.com/auth", "https://example.com",
scopes, false)
scopes, false, nil, nil)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
@@ -103,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
session := &mockSessionData{redirectCount: 5} // At the limit
req := httptest.NewRequest("GET", "/test", nil)
@@ -138,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -169,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -200,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -231,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
session := &mockSessionData{saveError: &testError{"save failed"}}
req := httptest.NewRequest("GET", "/test?param=value", nil)
@@ -275,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/protected/resource", nil)
@@ -378,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
[]string{"openid", "profile", "email"}, false)
[]string{"openid", "profile", "email"}, false, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -418,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/tenant/v2.0",
[]string{"openid", "profile", "email"}, false)
[]string{"openid", "profile", "email"}, false, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -446,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false)
[]string{"openid"}, false, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
@@ -471,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"no-pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false)
[]string{"openid"}, false, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
@@ -543,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
"test-client", "https://example.com/auth", "https://example.com",
tt.scopes, tt.overrideScopes)
tt.scopes, tt.overrideScopes, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -597,3 +619,550 @@ type testError struct {
func (e *testError) Error() string {
return e.message
}
// SCOPE FILTERING INTEGRATION TESTS
// TestAuthHandler_BuildAuthURL_WithScopeFiltering tests scope filtering when enabled
func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
// Requested scopes include offline_access
scopes := []string{"openid", "profile", "email", "offline_access"}
// Provider only supports these
scopesSupported := []string{"openid", "profile", "email"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
actualScopes := strings.Split(actualScope, " ")
// offline_access should have been filtered out in the first pass
// The standard provider logic then tries to add it back
// But the final filtering pass removes it again
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("offline_access should have been filtered out when not in scopesSupported")
}
}
// Should contain the supported scopes
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in final scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in final scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in final scope string")
}
}
// TestAuthHandler_BuildAuthURL_WithoutScopeFiltering tests backward compatibility
func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
logger := &mockLogger{}
scopes := []string{"openid", "profile", "email"}
// No scopeFilter or scopesSupported (backward compatibility)
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, nil, nil)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
// All scopes should be present, plus offline_access added by standard provider logic
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in scope string")
}
if !strings.Contains(actualScope, "offline_access") {
t.Error("Expected offline_access added by standard provider logic")
}
}
// TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess tests GitLab scenario
func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email", "offline_access"}
// GitLab discovery doc doesn't include offline_access
scopesSupported := []string{"openid", "profile", "email", "read_user", "read_api"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
"https://gitlab.example.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
actualScopes := strings.Split(actualScope, " ")
// offline_access should be filtered out
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("GitLab scenario: offline_access should have been filtered out")
}
}
// Should contain standard scopes
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in final scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in final scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in final scope string")
}
}
// TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess tests Google provider
func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email", "offline_access"}
scopesSupported := []string{"openid", "profile", "email"}
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
"https://accounts.google.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
actualScope := query.Get("scope")
actualScopes := strings.Split(actualScope, " ")
// Google removes offline_access and uses access_type=offline instead
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("Google scenario: offline_access should have been removed by Google-specific logic")
}
}
// Google-specific parameters should be present
if query.Get("access_type") != "offline" {
t.Error("Expected access_type=offline for Google")
}
if query.Get("prompt") != "consent" {
t.Error("Expected prompt=consent for Google")
}
}
// TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess tests Azure provider
func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email"}
// Azure supports offline_access
scopesSupported := []string{"openid", "profile", "email", "offline_access"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/tenant/v2.0",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
actualScope := query.Get("scope")
// Azure should add offline_access automatically and it should pass filtering
if !strings.Contains(actualScope, "offline_access") {
t.Error("Azure scenario: offline_access should be present")
}
// Azure-specific parameter
if query.Get("response_mode") != "query" {
t.Error("Expected response_mode=query for Azure")
}
}
// TestAuthHandler_BuildAuthURL_GenericWithFiltering tests generic provider with discovery filtering
func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email", "custom_scope", "offline_access"}
scopesSupported := []string{"openid", "profile", "email", "custom_scope"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"generic-client", "https://auth.provider.com/authorize",
"https://auth.provider.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
// Should contain supported scopes including custom_scope
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "custom_scope") {
t.Error("Expected custom_scope in scope string")
}
// offline_access should be filtered out (not in scopesSupported)
actualScopes := strings.Split(actualScope, " ")
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("offline_access should have been filtered out when not supported")
}
}
}
// TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering tests override scopes + filtering
func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
// User explicitly overrides scopes
scopes := []string{"openid", "custom:read", "custom:write"}
scopesSupported := []string{"openid", "custom:read"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, true, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
actualScopes := strings.Split(actualScope, " ")
// Should contain only supported scopes from override
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "custom:read") {
t.Error("Expected custom:read in scope string")
}
// custom:write should be filtered out
for _, scope := range actualScopes {
if scope == "custom:write" {
t.Error("custom:write should have been filtered out (not supported)")
}
}
// offline_access should NOT be auto-added when overrideScopes=true
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("offline_access should not be auto-added when user overrides scopes")
}
}
}
// TestAuthHandler_BuildAuthURL_DoubleFiltering tests initial + final filtering passes
func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email"}
// Provider supports offline_access
scopesSupported := []string{"openid", "profile", "email", "offline_access"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
// Initial filtering: All requested scopes pass (all in scopesSupported)
// Provider-specific logic: Adds offline_access (standard provider)
// Final filtering: offline_access should still be present (it's in scopesSupported)
if !strings.Contains(actualScope, "offline_access") {
t.Error("offline_access should be present (supported by provider and added by logic)")
}
// Original scopes should be present
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in scope string")
}
}
// TestAuthHandler_BuildAuthURL_NoScopeFilterProvided tests when scopeFilter is nil
func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
logger := &mockLogger{}
scopes := []string{"openid", "profile", "email"}
scopesSupported := []string{"openid", "profile"} // Even with scopesSupported, no filter
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, nil, scopesSupported) // scopeFilter is nil
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
// Without scopeFilter, all scopes should be present (no filtering)
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in scope string (no filtering without scopeFilter)")
}
}
// TestAuthHandler_BuildAuthURL_EmptyScopesSupported tests empty scopesSupported list
func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "email"}
scopesSupported := []string{} // Empty - backward compatibility mode
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
actualScope := parsedURL.Query().Get("scope")
// With empty scopesSupported, mockScopeFilter returns requested scopes unchanged
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope string")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in scope string")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in scope string")
}
}
// TestAuthHandler_BuildAuthURL_FilteringWithPKCE tests scope filtering with PKCE enabled
func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "offline_access"}
scopesSupported := []string{"openid", "profile"}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// PKCE parameters should be present
if query.Get("code_challenge") != "test-challenge" {
t.Error("Expected code_challenge parameter with PKCE enabled")
}
if query.Get("code_challenge_method") != "S256" {
t.Error("Expected code_challenge_method=S256 with PKCE enabled")
}
// Scope filtering should still work
actualScope := query.Get("scope")
actualScopes := strings.Split(actualScope, " ")
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("offline_access should have been filtered out even with PKCE")
}
}
}
// TestAuthHandler_BuildAuthURL_ComplexScenario tests realistic complex scenario
func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
// User configures: openid, profile, email, custom:read, offline_access
scopes := []string{"openid", "profile", "email", "custom:read", "offline_access"}
// Provider discovery returns: openid, profile, email, custom:read, custom:write, admin:all
scopesSupported := []string{"openid", "profile", "email", "custom:read", "custom:write", "admin:all"}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
scopes, false, scopeFilter, scopesSupported)
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
parsedURL, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Failed to parse auth URL: %v", err)
}
query := parsedURL.Query()
// Verify basic OAuth parameters
if query.Get("client_id") != "complex-client" {
t.Error("Expected correct client_id")
}
if query.Get("response_type") != "code" {
t.Error("Expected response_type=code")
}
if query.Get("state") != "state-123" {
t.Error("Expected correct state")
}
if query.Get("nonce") != "nonce-456" {
t.Error("Expected correct nonce")
}
// Verify PKCE parameters
if query.Get("code_challenge") != "challenge-789" {
t.Error("Expected correct code_challenge")
}
// Verify scope filtering
actualScope := query.Get("scope")
// Should contain: openid, profile, email, custom:read
if !strings.Contains(actualScope, "openid") {
t.Error("Expected openid in scope")
}
if !strings.Contains(actualScope, "profile") {
t.Error("Expected profile in scope")
}
if !strings.Contains(actualScope, "email") {
t.Error("Expected email in scope")
}
if !strings.Contains(actualScope, "custom:read") {
t.Error("Expected custom:read in scope")
}
// offline_access should be filtered (not in scopesSupported)
actualScopes := strings.Split(actualScope, " ")
for _, scope := range actualScopes {
if scope == "offline_access" {
t.Error("offline_access should have been filtered (not in scopesSupported)")
}
}
}
// TestAuthHandler_BuildAuthURL_LoggingVerification tests that logging occurs correctly
func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
logger := &mockLogger{}
scopeFilter := &mockScopeFilter{}
scopes := []string{"openid", "profile", "offline_access"}
scopesSupported := []string{"openid", "profile"}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
// Should have logged debug messages about filtering
if len(logger.debugMessages) == 0 {
t.Error("Expected debug messages to be logged during scope filtering")
}
// Verify specific log messages were generated
hasDiscoveryFilterLog := false
hasFinalFilterLog := false
hasFinalScopeLog := false
for _, msg := range logger.debugMessages {
if strings.Contains(msg, "After discovery filtering") {
hasDiscoveryFilterLog = true
}
if strings.Contains(msg, "After final filtering") {
hasFinalFilterLog = true
}
if strings.Contains(msg, "Final scope string being sent") {
hasFinalScopeLog = true
}
}
if !hasDiscoveryFilterLog {
t.Error("Expected log message about discovery filtering")
}
if !hasFinalFilterLog {
t.Error("Expected log message about final filtering")
}
if !hasFinalScopeLog {
t.Error("Expected log message about final scope string")
}
}
+5 -5
View File
@@ -10,7 +10,7 @@ import (
func TestAuthHandler_validateURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
func TestAuthHandler_validateHost(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
func TestAuthHandler_buildURLWithParams(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
// Test special characters that need encoding
params := url.Values{
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
func TestAuthHandler_validateParsedURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
tests := []struct {
name string
+1
View File
@@ -58,6 +58,7 @@ func TestAzureOIDCRegression(t *testing.T) {
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
scopes: []string{"openid", "profile", "email"},
refreshGracePeriod: 60 * time.Second,
+189 -4
View File
@@ -89,8 +89,9 @@ scopes: ["openid", "profile", "email", "offline_access"]
- **Offline access**: Requires `offline_access` scope for refresh tokens
- **Access token validation**: Supports both JWT and opaque access tokens
- **Tenant isolation**: Can restrict to specific Azure AD tenants
- **Application ID URI**: Supports custom audience for protected APIs
### Example Configuration
### Example Configuration (Basic)
```yaml
http:
middlewares:
@@ -108,6 +109,33 @@ http:
forceHttps: true
```
### Azure AD API Configuration (Application ID URI)
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
```yaml
http:
middlewares:
azure-api-oidc:
plugin:
traefik-oidc:
providerUrl: "https://login.microsoftonline.com/common/v2.0"
clientId: "12345678-1234-1234-1234-123456789abc"
clientSecret: "your-azure-client-secret"
# Specify the Application ID URI as audience
audience: "api://12345678-1234-1234-1234-123456789abc"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
forceHttps: true
```
**Important**:
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
- Without the `audience` parameter, access tokens with custom audiences will be rejected
- For ID token validation only (no API access), you can omit the `audience` parameter
### Azure App Registration Setup
1. Go to [Azure Portal](https://portal.azure.com/)
2. Navigate to "Azure Active Directory" > "App registrations"
@@ -116,6 +144,12 @@ http:
5. Create client secret in "Certificates & secrets"
6. Configure API permissions for required scopes
### Azure AD API Exposure Setup (for custom audiences)
1. In your App Registration, go to "Expose an API"
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
3. Add any custom scopes your API exposes
4. Update the middleware configuration to include the `audience` parameter with this URI
---
## Auth0
@@ -138,8 +172,9 @@ scopes: ["openid", "profile", "email", "offline_access"]
- **Rules and hooks**: Leverages Auth0's extensibility
- **Social connections**: Works with Auth0's social identity providers
- **Offline access**: Requires `offline_access` scope
- **API audiences**: Supports custom audience for API access tokens
### Example Configuration
### Example Configuration (Basic)
```yaml
http:
middlewares:
@@ -158,6 +193,34 @@ http:
enablePkce: true
```
### Auth0 API Configuration (Custom Audience)
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
```yaml
http:
middlewares:
auth0-api-oidc:
plugin:
traefik-oidc:
providerUrl: "https://company.auth0.com"
clientId: "abcdef123456789"
clientSecret: "your-auth0-client-secret"
# Specify the Auth0 API identifier as audience
audience: "https://api.company.com"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
forceHttps: true
enablePkce: true
```
**Important**:
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
- For ID token validation only (no APIs), you can omit the `audience` parameter
### Auth0 Application Setup
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
2. Create new application (Regular Web Application)
@@ -165,6 +228,14 @@ http:
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
5. Enable OIDC Conformant in Advanced Settings
### Auth0 API Setup (for custom audiences)
1. Go to Auth0 Dashboard → APIs
2. Create a new API or select existing API
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
4. In API Settings → Machine to Machine Applications, authorize your application
5. Configure API permissions/scopes as needed
6. Use the API identifier as the `audience` parameter in your configuration
---
## GitHub
@@ -236,7 +307,7 @@ scopes: ["openid", "profile", "email"]
- **Self-hosted support**: Works with self-hosted GitLab instances
- **Group membership**: Can restrict by GitLab groups
- **Project access**: Can validate project permissions
- **Offline access**: Supports refresh tokens with `offline_access`
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
### Example Configuration
```yaml
@@ -250,7 +321,9 @@ http:
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://app.example.com/auth/callback"
logoutUrl: "https://app.example.com/auth/logout"
scopes: ["openid", "profile", "email", "offline_access"]
scopes: ["openid", "profile", "email"]
# Note: GitLab doesn't support the offline_access scope.
# Refresh tokens are issued automatically for the openid scope.
allowedRolesAndGroups: ["developers", "maintainers"]
forceHttps: true
enablePkce: true
@@ -459,8 +532,120 @@ http:
---
## Automatic Scope Filtering
### Overview
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
### How It Works
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
### Example Scenarios
#### Self-Hosted GitLab
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
```
The requested scope is invalid, unknown, or malformed.
```
**Solution**: The middleware automatically detects this by:
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
2. Observing that `offline_access` is NOT in the `scopes_supported` list
3. Filtering out `offline_access` from the request
4. Authentication succeeds
**Configuration**:
```yaml
http:
middlewares:
gitlab-oidc:
plugin:
traefik-oidc:
providerUrl: "https://gitlab.example.com"
clientId: "your-gitlab-application-id"
clientSecret: "your-gitlab-application-secret"
callbackUrl: "https://app.example.com/auth/callback"
scopes: ["openid", "profile", "email", "offline_access"]
# Even though offline_access is listed, it will be automatically
# filtered out if GitLab doesn't support it
```
#### Auth0 or Keycloak
These providers typically support `offline_access` and it will be included:
```yaml
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
# Result: All requested scopes are sent
```
### Benefits
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
2. **No Manual Configuration**: No need to know which scopes each provider supports
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
5. **Backward Compatible**: Existing configurations continue to work
### Logging
The middleware provides detailed logging for scope filtering:
```
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
```
### Troubleshooting
**Issue**: Provider rejects scope even after filtering
**Possible Causes**:
1. Provider's discovery document is outdated
2. Provider doesn't properly implement `scopes_supported`
3. Custom authorization server with non-standard behavior
**Solutions**:
1. Use `overrideScopes: true` and explicitly list only supported scopes
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
3. Review middleware debug logs for filtering decisions
---
## Common Configuration Options
### Audience Configuration
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
```yaml
# Optional: Custom audience for JWT validation
# If not set, defaults to clientID for backward compatibility
audience: "https://api.example.com" # Auth0 API identifier
# OR
audience: "api://12345-guid" # Azure AD Application ID URI
```
**When to use**:
- **Auth0**: When using Auth0 APIs with custom audience parameters
- **Azure AD**: When exposing your app as an API with Application ID URI
- **Keycloak**: When using audience-restricted tokens
- **Okta**: When using custom authorization servers with API audiences
**When to omit**:
- For standard ID token validation (default behavior)
- When the provider sets `aud` claim to your `clientID`
- For backward compatibility with existing configurations
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
### Security Settings
```yaml
# Force HTTPS (recommended for production)
+13 -3
View File
@@ -124,7 +124,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
// Read tokenURL with RLock
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -355,8 +360,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
// Read endSessionURL with RLock
t.metadataMu.RLock()
endSessionURL := t.endSessionURL
t.metadataMu.RUnlock()
if endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
+3 -1
View File
@@ -155,7 +155,9 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
return p
}
case ProviderTypeGitLab:
if strings.Contains(host, "gitlab.com") {
// Match gitlab.com, self-hosted (gitlab.*), and instances with gitlab in subdomain
if strings.Contains(host, "gitlab.com") ||
strings.Contains(host, "gitlab") {
return p
}
}
+221
View File
@@ -1,6 +1,7 @@
package providers
import (
"fmt"
"sync"
"testing"
)
@@ -238,6 +239,26 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
issuerURL: "https://gitlab.com/oauth",
expected: gitlabProvider,
},
{
name: "GitLab self-hosted detection - gitlab subdomain",
issuerURL: "https://gitlab.example.com",
expected: gitlabProvider,
},
{
name: "GitLab self-hosted detection - gitlab in domain",
issuerURL: "https://my-gitlab.company.io",
expected: gitlabProvider,
},
{
name: "GitLab self-hosted detection - gitlab prefix",
issuerURL: "https://gitlab-prod.internal.net",
expected: gitlabProvider,
},
{
name: "GitLab self-hosted detection - gitlab suffix",
issuerURL: "https://company-gitlab.net",
expected: gitlabProvider,
},
{
name: "Generic provider fallback",
issuerURL: "https://auth.example.com",
@@ -482,6 +503,206 @@ func TestProviderRegistry_DoubleCheckedLocking(t *testing.T) {
}
}
// TestProviderRegistry_DetectGitLabSelfHosted tests improved GitLab detection for issue #61
func TestProviderRegistry_DetectGitLabSelfHosted(t *testing.T) {
registry := NewProviderRegistry()
genericProvider := NewGenericProvider()
gitlabProvider := NewGitLabProvider()
githubProvider := NewGitHubProvider()
registry.RegisterProvider(genericProvider)
registry.RegisterProvider(gitlabProvider)
registry.RegisterProvider(githubProvider)
tests := []struct {
name string
issuerURL string
expected OIDCProvider
description string
}{
{
name: "GitLab.com official",
issuerURL: "https://gitlab.com",
expected: gitlabProvider,
description: "Should detect official GitLab.com",
},
{
name: "GitLab.com with path",
issuerURL: "https://gitlab.com/oauth/authorize",
expected: gitlabProvider,
description: "Should detect GitLab.com with path",
},
{
name: "Self-hosted gitlab.example.com",
issuerURL: "https://gitlab.example.com",
expected: gitlabProvider,
description: "Should detect gitlab as subdomain",
},
{
name: "Self-hosted my.gitlab.io",
issuerURL: "https://my.gitlab.io",
expected: gitlabProvider,
description: "Should detect gitlab in domain",
},
{
name: "Self-hosted example-gitlab.com",
issuerURL: "https://example-gitlab.com",
expected: gitlabProvider,
description: "Should detect gitlab as suffix",
},
{
name: "Self-hosted gitlab-prod.company.net",
issuerURL: "https://gitlab-prod.company.net",
expected: gitlabProvider,
description: "Should detect gitlab as prefix",
},
{
name: "Self-hosted my-gitlab.internal",
issuerURL: "https://my-gitlab.internal",
expected: gitlabProvider,
description: "Should detect gitlab in middle of host",
},
{
name: "Self-hosted company.gitlab.services",
issuerURL: "https://company.gitlab.services",
expected: gitlabProvider,
description: "Should detect gitlab in middle of domain",
},
{
name: "Self-hosted with port",
issuerURL: "https://gitlab.example.com:8443",
expected: gitlabProvider,
description: "Should detect GitLab with custom port",
},
{
name: "Self-hosted with path and query",
issuerURL: "https://gitlab.example.com/oauth?param=value",
expected: gitlabProvider,
description: "Should detect GitLab with complex URL",
},
{
name: "Case insensitive - GITLAB",
issuerURL: "https://GITLAB.example.com",
expected: gitlabProvider,
description: "Should detect GitLab case-insensitively",
},
{
name: "Case insensitive - GitLab",
issuerURL: "https://GitLab.example.com",
expected: gitlabProvider,
description: "Should detect GitLab with mixed case",
},
{
name: "Not GitLab - git prefix only",
issuerURL: "https://github.com",
expected: githubProvider, // Should match GitHub provider, not GitLab
description: "Should not match github.com as GitLab",
},
{
name: "Not GitLab - lab suffix only",
issuerURL: "https://mylab.example.com",
expected: genericProvider,
description: "Should not match partial gitlab string",
},
{
name: "Not GitLab - git and lab separate",
issuerURL: "https://git.mylab.example.com",
expected: genericProvider,
description: "Should not match git and lab when not together",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear cache to ensure fresh detection
registry.ClearCache()
result := registry.DetectProvider(tt.issuerURL)
if result != tt.expected {
t.Errorf("%s: Expected %v, got %v", tt.description, tt.expected, result)
}
})
}
}
// TestProviderRegistry_GitLabDetection_RealWorldURLs tests real-world GitLab URLs
func TestProviderRegistry_GitLabDetection_RealWorldURLs(t *testing.T) {
registry := NewProviderRegistry()
genericProvider := NewGenericProvider()
gitlabProvider := NewGitLabProvider()
githubProvider := NewGitHubProvider()
registry.RegisterProvider(genericProvider)
registry.RegisterProvider(gitlabProvider)
registry.RegisterProvider(githubProvider)
realWorldTests := []struct {
name string
issuerURL string
expected OIDCProvider
}{
// Actual self-hosted GitLab examples from issue #61
{
name: "Company self-hosted GitLab",
issuerURL: "https://gitlab.company.com",
expected: gitlabProvider,
},
{
name: "Organization GitLab instance with gitlab in subdomain",
issuerURL: "https://gitlab.organization.org",
expected: gitlabProvider,
},
{
name: "Internal GitLab server",
issuerURL: "https://gitlab.internal.corp",
expected: gitlabProvider,
},
{
name: "GitLab with custom subdomain",
issuerURL: "https://code.gitlab.mycompany.com",
expected: gitlabProvider,
},
// Negative cases to ensure we don't over-match
{
name: "GitHub should not match GitLab",
issuerURL: "https://github.com",
expected: githubProvider,
},
{
name: "Generic git server",
issuerURL: "https://git.example.com",
expected: genericProvider,
},
}
for _, tt := range realWorldTests {
t.Run(tt.name, func(t *testing.T) {
registry.ClearCache()
result := registry.DetectProvider(tt.issuerURL)
if result != tt.expected {
var expectedType, resultType string
if tt.expected != nil {
expectedType = fmt.Sprintf("%v", tt.expected.GetType())
} else {
expectedType = "nil"
}
if result != nil {
resultType = fmt.Sprintf("%v", result.GetType())
} else {
resultType = "nil"
}
t.Errorf("Expected provider type %s, got %s for URL %s",
expectedType, resultType, tt.issuerURL)
}
})
}
}
// Benchmark tests
func BenchmarkProviderRegistry_DetectProvider_Cached(b *testing.B) {
registry := NewProviderRegistry()
+3
View File
@@ -586,6 +586,7 @@ func TestIssue67_TokenResilienceRecursionBug(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenResilienceManager: resilienceManager,
tokenHTTPClient: &http.Client{
@@ -671,6 +672,7 @@ func TestIssue67_TokenResilienceManager_NoRecursion(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenResilienceManager: resilienceManager,
tokenHTTPClient: &http.Client{
@@ -738,6 +740,7 @@ func TestIssue67_DirectRecursionDetection(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test",
audience: "test",
clientSecret: "test",
tokenResilienceManager: NewTokenResilienceManager(config, logger),
tokenHTTPClient: &http.Client{Timeout: 2 * time.Second},
+3 -3
View File
@@ -257,12 +257,12 @@ func parseJWT(tokenString string) (*JWT, error) {
// not-before time (if present), and prevents replay attacks using JTI claims.
// Parameters:
// - issuerURL: Expected issuer URL to validate against
// - clientID: Expected audience (client ID) to validate against
// - expectedAudience: Expected audience to validate against (can be clientID or custom audience)
// - skipReplayCheck: Optional parameter to skip replay attack protection
//
// Returns:
// - An error describing the first validation failure encountered
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
func (j *JWT) Verify(issuerURL, expectedAudience string, skipReplayCheck ...bool) error {
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
@@ -290,7 +290,7 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
if !ok {
return fmt.Errorf("missing 'aud' claim")
}
if err := verifyAudience(aud, clientID); err != nil {
if err := verifyAudience(aud, expectedAudience); err != nil {
return err
}
+21
View File
@@ -11,12 +11,30 @@ var (
singletonNoOpLogger *Logger
// noOpLoggerOnce ensures the singleton is created only once
noOpLoggerOnce sync.Once
// noOpLoggerMu protects access to the singleton logger during reset
noOpLoggerMu sync.RWMutex
)
// GetSingletonNoOpLogger returns the singleton no-op logger instance.
// This reduces memory allocation by reusing the same no-op logger
// instance across the entire application.
func GetSingletonNoOpLogger() *Logger {
noOpLoggerMu.RLock()
if singletonNoOpLogger != nil {
logger := singletonNoOpLogger
noOpLoggerMu.RUnlock()
return logger
}
noOpLoggerMu.RUnlock()
noOpLoggerMu.Lock()
defer noOpLoggerMu.Unlock()
// Double-check after acquiring write lock
if singletonNoOpLogger != nil {
return singletonNoOpLogger
}
noOpLoggerOnce.Do(func() {
singletonNoOpLogger = &Logger{
logError: log.New(io.Discard, "", 0),
@@ -29,6 +47,9 @@ func GetSingletonNoOpLogger() *Logger {
// ResetSingletonNoOpLogger resets the singleton instance (mainly for testing)
func ResetSingletonNoOpLogger() {
noOpLoggerMu.Lock()
defer noOpLoggerMu.Unlock()
noOpLoggerOnce = sync.Once{}
singletonNoOpLogger = nil
}
+18
View File
@@ -157,6 +157,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
metadataCache: cacheManager.GetSharedMetadataCache(),
clientID: config.ClientID,
clientSecret: config.ClientSecret,
audience: func() string {
if config.Audience != "" {
return config.Audience
}
return config.ClientID
}(),
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
overrideScopes: config.OverrideScopes,
@@ -192,6 +198,14 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
}
// Log audience configuration
if config.Audience != "" && config.Audience != config.ClientID {
t.logger.Infof("Custom audience configured: %s", config.Audience)
} else {
t.logger.Debugf("No custom audience specified, using clientID as audience: %s", t.clientID)
}
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, t.logger)
@@ -345,7 +359,11 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
// Parameters:
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
t.metadataMu.Lock()
defer t.metadataMu.Unlock()
t.jwksURL = metadata.JWKSURL
t.scopesSupported = metadata.ScopesSupported // NEW - store supported scopes from discovery
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
+18 -1
View File
@@ -10,6 +10,7 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync/atomic"
"testing"
"time"
)
@@ -37,6 +38,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -82,6 +84,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
enablePKCE: true,
tokenHTTPClient: &http.Client{
@@ -116,6 +119,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/invalid",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -146,6 +150,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/expired",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -176,6 +181,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/timeout",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 100 * time.Millisecond,
@@ -206,6 +212,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/error",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -236,6 +243,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/malformed",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -266,6 +274,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/incomplete",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -299,6 +308,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/slow",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -329,6 +339,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/ratelimit",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -482,13 +493,17 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
// TestExchangeCodeForToken_Integration tests integration scenarios
func TestExchangeCodeForToken_Integration(t *testing.T) {
t.Run("multiple concurrent exchanges", func(t *testing.T) {
// Use atomic counter for unique token generation to handle race detector slowdown
var tokenCounter int64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add small delay to test concurrency
time.Sleep(10 * time.Millisecond)
// Generate unique token using atomic counter
tokenID := atomic.AddInt64(&tokenCounter, 1)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: fmt.Sprintf("token_%d", time.Now().UnixNano()),
AccessToken: fmt.Sprintf("token_%d", tokenID),
IDToken: "test_id_token",
RefreshToken: "test_refresh_token",
TokenType: "Bearer",
@@ -500,6 +515,7 @@ func TestExchangeCodeForToken_Integration(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -586,6 +602,7 @@ func TestExchangeCodeForToken_Integration(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
+14
View File
@@ -30,6 +30,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -71,6 +72,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/expired",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -97,6 +99,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/invalid",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -123,6 +126,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/revoked",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -149,6 +153,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/timeout",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 100 * time.Millisecond,
@@ -175,6 +180,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/error",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -201,6 +207,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/malformed",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -228,6 +235,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/partial",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -259,6 +267,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/ratelimit",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -285,6 +294,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -315,6 +325,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
return &TraefikOidc{
tokenURL: server.URL + "/token/rotating",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -519,6 +530,7 @@ func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -588,6 +600,7 @@ func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
@@ -642,6 +655,7 @@ func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
oidc := &TraefikOidc{
tokenURL: server.URL + "/token",
clientID: "test_client",
audience: "test_client",
clientSecret: "test_secret",
tokenHTTPClient: &http.Client{
Timeout: 10 * time.Second,
+2
View File
@@ -192,6 +192,7 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
logoutURLPath: "/logout",
tokenURL: "https://provider.example.com/token",
clientID: "test-client",
audience: "test-client",
clientSecret: "test-secret",
tokenHTTPClient: http.DefaultClient,
}
@@ -297,6 +298,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
logger: NewLogger("debug"),
authURL: "https://provider.example.com/auth",
clientID: "test-client",
audience: "test-client",
redirURLPath: "/callback",
}
},
+3
View File
@@ -124,6 +124,7 @@ func (ts *TestSuite) Setup() {
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -1304,6 +1305,7 @@ func TestHandleCallback(t *testing.T) {
// Add potentially missing fields based on New() comparison
clientID: ts.tOidc.clientID,
audience: ts.tOidc.clientID,
issuerURL: ts.tOidc.issuerURL,
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
httpClient: ts.tOidc.httpClient,
@@ -1668,6 +1670,7 @@ func TestHandleLogout(t *testing.T) {
tokenBlacklist: NewCache(), // Use generic cache for blacklist
httpClient: &http.Client{},
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
+6 -1
View File
@@ -46,7 +46,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
if issuerURL == "" {
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
return
+97
View File
@@ -0,0 +1,97 @@
package traefikoidc
import (
"strings"
)
// ScopeFilterLogger interface for dependency injection
type ScopeFilterLogger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// ScopeFilter handles OAuth scope validation and filtering based on provider capabilities.
type ScopeFilter struct {
logger ScopeFilterLogger
}
// NewScopeFilter creates a new ScopeFilter instance.
func NewScopeFilter(logger ScopeFilterLogger) *ScopeFilter {
return &ScopeFilter{
logger: logger,
}
}
// FilterSupportedScopes returns the intersection of requested and supported scopes.
// It preserves the order of requested scopes and returns all requested scopes
// if supportedScopes is empty (fallback for providers without scopes_supported).
//
// Parameters:
// - requestedScopes: Scopes the application wants to request
// - supportedScopes: Scopes advertised by the provider (from discovery doc)
// - providerURL: Provider URL for logging purposes
//
// Returns:
// - Filtered list of scopes safe to request from the provider
func (sf *ScopeFilter) FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string {
// If no supported scopes declared, return all requested (backward compatibility)
if len(supportedScopes) == 0 {
sf.logger.Debugf("ScopeFilter: Provider %s has no scopes_supported in discovery doc, using all requested scopes", providerURL)
return requestedScopes
}
// Build lookup map for efficient checking
supportedMap := make(map[string]bool, len(supportedScopes))
for _, scope := range supportedScopes {
supportedMap[strings.TrimSpace(scope)] = true
}
// Filter requested scopes
filtered := make([]string, 0, len(requestedScopes))
removed := make([]string, 0)
for _, scope := range requestedScopes {
trimmed := strings.TrimSpace(scope)
if trimmed == "" {
continue
}
if supportedMap[trimmed] {
filtered = append(filtered, trimmed)
} else {
removed = append(removed, trimmed)
}
}
// Log filtering results
if len(removed) > 0 {
sf.logger.Infof("ScopeFilter: Filtered unsupported scopes for %s: %v (not in provider's scopes_supported)",
providerURL, removed)
sf.logger.Debugf("ScopeFilter: Provider %s supported scopes: %v", providerURL, supportedScopes)
sf.logger.Debugf("ScopeFilter: Final filtered scopes: %v", filtered)
} else {
sf.logger.Debugf("ScopeFilter: All requested scopes are supported by %s", providerURL)
}
// If all scopes were filtered out, return at least "openid"
if len(filtered) == 0 {
sf.logger.Infof("ScopeFilter: All scopes filtered out for %s, falling back to 'openid'", providerURL)
return []string{"openid"}
}
return filtered
}
// EnsureOpenIDScope ensures "openid" scope is present in the scope list.
// This is required for OIDC compliance.
func (sf *ScopeFilter) EnsureOpenIDScope(scopes []string) []string {
for _, scope := range scopes {
if scope == "openid" {
return scopes
}
}
sf.logger.Debugf("ScopeFilter: Adding required 'openid' scope")
return append([]string{"openid"}, scopes...)
}
+724
View File
@@ -0,0 +1,724 @@
package traefikoidc
import (
"reflect"
"testing"
)
// mockLogger for testing
type mockScopeFilterLogger struct {
debugMessages []string
infoMessages []string
errorMessages []string
}
func (l *mockScopeFilterLogger) Debugf(format string, args ...interface{}) {
l.debugMessages = append(l.debugMessages, format)
}
func (l *mockScopeFilterLogger) Infof(format string, args ...interface{}) {
l.infoMessages = append(l.infoMessages, format)
}
func (l *mockScopeFilterLogger) Errorf(format string, args ...interface{}) {
l.errorMessages = append(l.errorMessages, format)
}
// TestNewScopeFilter tests the ScopeFilter constructor
func TestNewScopeFilter(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
if filter == nil {
t.Fatal("Expected ScopeFilter to be created, got nil")
}
// Logger is set correctly (we can't directly compare interface values)
if filter.logger == nil {
t.Error("Logger not set in ScopeFilter")
}
}
// TestFilterSupportedScopes_AllSupported tests when all requested scopes are supported
func TestFilterSupportedScopes_AllSupported(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email"}
supported := []string{"openid", "profile", "email", "address", "phone"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
// Should log debug message that all scopes are supported
if len(logger.debugMessages) == 0 {
t.Error("Expected debug messages to be logged")
}
// Should not log any info messages (no filtering occurred)
if len(logger.infoMessages) > 0 {
t.Error("Expected no info messages when all scopes supported")
}
}
// TestFilterSupportedScopes_SomeFiltered tests when some scopes need to be filtered
func TestFilterSupportedScopes_SomeFiltered(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "offline_access", "custom_scope"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://gitlab.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
// Verify offline_access and custom_scope were filtered out
for _, scope := range result {
if scope == "offline_access" || scope == "custom_scope" {
t.Errorf("Scope '%s' should have been filtered out", scope)
}
}
// Should log info message about filtered scopes
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about filtered scopes")
}
// Should log debug messages about supported scopes and final result
if len(logger.debugMessages) < 2 {
t.Error("Expected debug messages about provider supported scopes and final result")
}
}
// TestFilterSupportedScopes_AllFiltered tests when all scopes are filtered (fallback to openid)
func TestFilterSupportedScopes_AllFiltered(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"custom_scope1", "custom_scope2", "unsupported"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected fallback to %v, got %v", expected, result)
}
// Should log info message about all scopes being filtered (falling back to openid)
if len(logger.infoMessages) < 2 { // One for filtered scopes, one for fallback
t.Error("Expected info messages when all scopes filtered")
}
// Should log info message about filtered scopes
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about filtered scopes")
}
}
// TestFilterSupportedScopes_NoSupportedScopes tests fallback behavior when no scopes_supported
func TestFilterSupportedScopes_NoSupportedScopes(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "offline_access"}
supported := []string{} // Empty supported list (backward compatibility)
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should return all requested scopes unchanged
if !reflect.DeepEqual(result, requested) {
t.Errorf("Expected all requested scopes %v, got %v", requested, result)
}
// Should log debug message about no scopes_supported
if len(logger.debugMessages) == 0 {
t.Error("Expected debug message about no scopes_supported")
}
// Should not log info messages (backward compatibility mode)
if len(logger.infoMessages) > 0 {
t.Error("Expected no info messages when no supported scopes provided")
}
}
// TestFilterSupportedScopes_EmptyRequested tests when requested scopes are empty
func TestFilterSupportedScopes_EmptyRequested(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{}
supported := []string{"openid", "profile", "email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should return openid as fallback
expected := []string{"openid"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected fallback to %v when requested empty, got %v", expected, result)
}
// Should log info message about empty result (fallback to openid)
if len(logger.infoMessages) == 0 {
t.Error("Expected info message when no scopes requested")
}
}
// TestFilterSupportedScopes_DuplicateScopes tests handling of duplicate scope names
func TestFilterSupportedScopes_DuplicateScopes(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "openid", "email"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should preserve duplicates from requested
expected := []string{"openid", "profile", "openid", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v (preserving duplicates), got %v", expected, result)
}
}
// TestFilterSupportedScopes_WhitespaceHandling tests trimming of whitespace
func TestFilterSupportedScopes_WhitespaceHandling(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{" openid ", "profile", " email"}
supported := []string{"openid", "profile", "email", "phone"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should trim whitespace from scopes
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected trimmed scopes %v, got %v", expected, result)
}
}
// TestFilterSupportedScopes_EmptyStrings tests filtering out empty strings
func TestFilterSupportedScopes_EmptyStrings(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "", "profile", " ", "email"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should filter out empty strings
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v (without empty strings), got %v", expected, result)
}
}
// TestFilterSupportedScopes_CasePreservation tests that scope case is preserved
func TestFilterSupportedScopes_CasePreservation(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"OpenID", "Profile", "Email"}
supported := []string{"OpenID", "Profile", "Email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should preserve case exactly
expected := []string{"OpenID", "Profile", "Email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected case-preserved %v, got %v", expected, result)
}
}
// TestFilterSupportedScopes_CaseSensitiveMatching tests case-sensitive matching
func TestFilterSupportedScopes_CaseSensitiveMatching(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "Profile", "EMAIL"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Only "openid" should match (case-sensitive)
// Profile and EMAIL won't match profile and email in supported list
expected := []string{"openid"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected case-sensitive filtering %v, got %v", expected, result)
}
// Should log info about filtered scopes
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about filtered scopes due to case mismatch")
}
}
// TestFilterSupportedScopes_OrderPreservation tests that order is preserved
func TestFilterSupportedScopes_OrderPreservation(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"email", "profile", "openid", "phone"}
supported := []string{"openid", "profile", "email", "phone", "address"}
providerURL := "https://auth.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
// Should preserve order from requested
expected := []string{"email", "profile", "openid", "phone"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected order-preserved %v, got %v", expected, result)
}
}
// TestFilterSupportedScopes_GitLabScenario simulates GitLab rejecting offline_access
func TestFilterSupportedScopes_GitLabScenario(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
// User requests offline_access but GitLab doesn't support it
requested := []string{"openid", "profile", "email", "offline_access"}
supported := []string{"openid", "profile", "email", "read_user", "read_api"}
providerURL := "https://gitlab.example.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v (without offline_access), got %v", expected, result)
}
// Verify offline_access was filtered out
for _, scope := range result {
if scope == "offline_access" {
t.Error("offline_access should have been filtered out for GitLab")
}
}
// Should log info about filtered scopes
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about offline_access being filtered")
}
}
// TestFilterSupportedScopes_GoogleScenario simulates Google's scope handling
func TestFilterSupportedScopes_GoogleScenario(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
// Google supports these standard scopes
requested := []string{"openid", "profile", "email"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://accounts.google.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
// No scopes should be filtered
if len(logger.infoMessages) > 0 {
t.Error("Expected no filtering for standard Google scopes")
}
}
// TestFilterSupportedScopes_AzureScenario simulates Azure's scope handling
func TestFilterSupportedScopes_AzureScenario(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
// Azure supports offline_access and OIDC scopes
requested := []string{"openid", "profile", "email", "offline_access"}
supported := []string{"openid", "profile", "email", "offline_access"}
providerURL := "https://login.microsoftonline.com/tenant"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email", "offline_access"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v (including offline_access), got %v", expected, result)
}
// All scopes should be retained
if len(logger.infoMessages) > 0 {
t.Error("Expected no filtering for standard Azure scopes with offline_access")
}
}
// TestFilterSupportedScopes_GenericWithFiltering simulates generic provider with filtering
func TestFilterSupportedScopes_GenericWithFiltering(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "offline_access", "custom:scope"}
supported := []string{"openid", "profile", "email", "custom:scope"}
providerURL := "https://auth.custom-provider.com"
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email", "custom:scope"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v (without offline_access), got %v", expected, result)
}
// offline_access should be filtered
for _, scope := range result {
if scope == "offline_access" {
t.Error("offline_access should have been filtered for this provider")
}
}
// Should log info about filtering
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about filtered offline_access")
}
}
// TestFilterSupportedScopes_MultipleProviderURLs tests different provider URLs
func TestFilterSupportedScopes_MultipleProviderURLs(t *testing.T) {
tests := []struct {
name string
providerURL string
requested []string
supported []string
expected []string
}{
{
name: "GitLab.com",
providerURL: "https://gitlab.com",
requested: []string{"openid", "offline_access"},
supported: []string{"openid"},
expected: []string{"openid"},
},
{
name: "Self-hosted GitLab",
providerURL: "https://gitlab.example.com",
requested: []string{"openid", "profile", "offline_access"},
supported: []string{"openid", "profile"},
expected: []string{"openid", "profile"},
},
{
name: "Keycloak",
providerURL: "https://keycloak.example.com/realms/master",
requested: []string{"openid", "profile", "email"},
supported: []string{"openid", "profile", "email", "offline_access"},
expected: []string{"openid", "profile", "email"},
},
{
name: "Auth0",
providerURL: "https://tenant.auth0.com",
requested: []string{"openid", "profile", "offline_access"},
supported: []string{"openid", "profile", "offline_access"},
expected: []string{"openid", "profile", "offline_access"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
result := filter.FilterSupportedScopes(tt.requested, tt.supported, tt.providerURL)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// TestEnsureOpenIDScope_Present tests when openid is already present
func TestEnsureOpenIDScope_Present(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
scopes := []string{"openid", "profile", "email"}
result := filter.EnsureOpenIDScope(scopes)
// Should return scopes unchanged
if !reflect.DeepEqual(result, scopes) {
t.Errorf("Expected scopes unchanged %v, got %v", scopes, result)
}
// Should not log anything (openid already present)
if len(logger.debugMessages) > 0 {
t.Error("Expected no debug messages when openid already present")
}
}
// TestEnsureOpenIDScope_Missing tests when openid needs to be added
func TestEnsureOpenIDScope_Missing(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
scopes := []string{"profile", "email"}
result := filter.EnsureOpenIDScope(scopes)
// Should prepend openid
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected openid prepended %v, got %v", expected, result)
}
// Should log debug message about adding openid
if len(logger.debugMessages) == 0 {
t.Error("Expected debug message about adding openid scope")
}
}
// TestEnsureOpenIDScope_Empty tests with empty scopes list
func TestEnsureOpenIDScope_Empty(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
scopes := []string{}
result := filter.EnsureOpenIDScope(scopes)
// Should return just openid
expected := []string{"openid"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
// Should log debug message
if len(logger.debugMessages) == 0 {
t.Error("Expected debug message about adding openid scope")
}
}
// TestEnsureOpenIDScope_Nil tests with nil scopes list
func TestEnsureOpenIDScope_Nil(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
var scopes []string // nil slice
result := filter.EnsureOpenIDScope(scopes)
// Should return just openid
expected := []string{"openid"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
}
// TestEnsureOpenIDScope_CaseVariations tests that case matters for openid detection
func TestEnsureOpenIDScope_CaseVariations(t *testing.T) {
tests := []struct {
name string
scopes []string
expected []string
}{
{
name: "Lowercase openid",
scopes: []string{"openid", "profile"},
expected: []string{"openid", "profile"},
},
{
name: "Mixed case OpenID (should add lowercase)",
scopes: []string{"OpenID", "profile"},
expected: []string{"openid", "OpenID", "profile"},
},
{
name: "OPENID uppercase (should add lowercase)",
scopes: []string{"OPENID", "profile"},
expected: []string{"openid", "OPENID", "profile"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
result := filter.EnsureOpenIDScope(tt.scopes)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
// TestFilterSupportedScopes_IntegrationScenario tests realistic end-to-end scenario
func TestFilterSupportedScopes_IntegrationScenario(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
// Simulate: User configures plugin with these scopes
requested := []string{"openid", "profile", "email", "offline_access", "custom_claim"}
// Provider discovery returns these supported scopes
supported := []string{"openid", "profile", "email", "read_user"}
providerURL := "https://gitlab.company.com"
// Filter should remove offline_access and custom_claim
result := filter.FilterSupportedScopes(requested, supported, providerURL)
expected := []string{"openid", "profile", "email"}
if !reflect.DeepEqual(result, expected) {
t.Errorf("Expected %v, got %v", expected, result)
}
// Verify logging occurred
if len(logger.infoMessages) == 0 {
t.Error("Expected info message about filtered scopes")
}
if len(logger.debugMessages) < 2 {
t.Error("Expected debug messages about supported scopes and final result")
}
// Verify specific scopes were filtered
for _, scope := range result {
if scope == "offline_access" || scope == "custom_claim" {
t.Errorf("Scope '%s' should have been filtered out", scope)
}
}
}
// TestFilterSupportedScopes_LoggingBehavior tests comprehensive logging scenarios
func TestFilterSupportedScopes_LoggingBehavior(t *testing.T) {
tests := []struct {
name string
requested []string
supported []string
expectDebugOnly bool
expectInfoLog bool
}{
{
name: "All supported - debug only",
requested: []string{"openid", "profile"},
supported: []string{"openid", "profile", "email"},
expectDebugOnly: true,
},
{
name: "Some filtered - info + debug",
requested: []string{"openid", "offline_access"},
supported: []string{"openid"},
expectInfoLog: true,
},
{
name: "All filtered - info + debug",
requested: []string{"custom1", "custom2"},
supported: []string{"openid"},
expectInfoLog: true,
},
{
name: "No supported scopes - debug only",
requested: []string{"openid"},
supported: []string{},
expectDebugOnly: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
filter.FilterSupportedScopes(tt.requested, tt.supported, "https://example.com")
hasDebug := len(logger.debugMessages) > 0
hasInfo := len(logger.infoMessages) > 0
if tt.expectDebugOnly && (!hasDebug || hasInfo) {
t.Errorf("Expected only debug logs, got debug=%v info=%v",
hasDebug, hasInfo)
}
if tt.expectInfoLog && !hasInfo {
t.Error("Expected info log but didn't get one")
}
})
}
}
// Benchmark tests
func BenchmarkFilterSupportedScopes_AllSupported(b *testing.B) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "phone"}
supported := []string{"openid", "profile", "email", "phone", "address"}
providerURL := "https://example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.FilterSupportedScopes(requested, supported, providerURL)
}
}
func BenchmarkFilterSupportedScopes_SomeFiltered(b *testing.B) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "offline_access", "custom"}
supported := []string{"openid", "profile", "email"}
providerURL := "https://example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.FilterSupportedScopes(requested, supported, providerURL)
}
}
func BenchmarkFilterSupportedScopes_NoSupported(b *testing.B) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
requested := []string{"openid", "profile", "email", "offline_access"}
supported := []string{}
providerURL := "https://example.com"
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.FilterSupportedScopes(requested, supported, providerURL)
}
}
func BenchmarkEnsureOpenIDScope_Present(b *testing.B) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
scopes := []string{"openid", "profile", "email"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.EnsureOpenIDScope(scopes)
}
}
func BenchmarkEnsureOpenIDScope_Missing(b *testing.B) {
logger := &mockScopeFilterLogger{}
filter := NewScopeFilter(logger)
scopes := []string{"profile", "email"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.EnsureOpenIDScope(scopes)
}
}
+6
View File
@@ -335,6 +335,7 @@ func TestJWTReplayAttack(t *testing.T) {
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -551,6 +552,7 @@ func TestSessionFixationAttack(t *testing.T) {
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -857,6 +859,7 @@ func TestTokenBlacklisting(t *testing.T) {
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -1278,6 +1281,7 @@ func TestRateLimiting(t *testing.T) {
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -1385,6 +1389,7 @@ func TestAuthorizationHeaderBypass(t *testing.T) {
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
@@ -1560,6 +1565,7 @@ func TestInvalidRedirectURI(t *testing.T) {
logoutURLPath: "/callback/logout",
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
+36 -7
View File
@@ -27,13 +27,19 @@ type TemplatedHeader struct {
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
HTTPClient *http.Client `json:"-"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
CookieDomain string `json:"cookieDomain"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
HTTPClient *http.Client `json:"-"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
CookieDomain string `json:"cookieDomain"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
// Audience specifies the expected JWT audience claim value.
// If not set, defaults to ClientID for backward compatibility.
// For Auth0 API access tokens with custom audiences, set this to your API identifier.
// For Azure AD with Application ID URI, set to "api://your-app-id".
// Security: This value is validated against the JWT aud claim to prevent token confusion attacks.
Audience string `json:"audience,omitempty"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
@@ -268,6 +274,29 @@ func (c *Config) Validate() error {
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
}
// Validate audience if specified
if c.Audience != "" {
// Validate audience format - should be a valid identifier or URL
if len(c.Audience) > 256 {
return fmt.Errorf("audience must not exceed 256 characters")
}
// If audience looks like a URL, validate it's HTTPS
if strings.HasPrefix(c.Audience, "http://") {
return fmt.Errorf("audience URL must use HTTPS, not HTTP")
}
// Prevent wildcard audiences which could weaken security
if strings.Contains(c.Audience, "*") {
return fmt.Errorf("audience must not contain wildcards")
}
// Validate that audience doesn't contain obvious injection patterns
if strings.ContainsAny(c.Audience, "\n\r\t\x00") {
return fmt.Errorf("audience contains invalid characters")
}
}
// Validate headers configuration for template security
for _, header := range c.Headers {
if header.Name == "" {
+29 -6
View File
@@ -276,6 +276,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
t.Run("SingletonTasksAcrossInstances", func(t *testing.T) {
// Reset singletons to ensure clean state
ResetGlobalTaskRegistry() // Reset circuit breaker and task registry
resetResourceManagerForTesting()
ResetUniversalCacheManagerForTesting()
defer ResetUniversalCacheManagerForTesting()
@@ -312,13 +313,35 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
plugins = append(plugins, plugin)
}
// Wait for cleanup to run multiple times
time.Sleep(350 * time.Millisecond)
// Wait for cleanup to run at least 2 times with adaptive timeout
// This handles race detector overhead which can slow goroutine scheduling significantly
// When running as part of full test suite, CPU contention is even higher, so use generous timeout
const minExpectedCount = 2
const maxExpectedCount = 5
timeout := time.After(5 * time.Second)
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
// Check that cleanup ran but not excessively (should be singleton)
count := atomic.LoadInt32(&cleanupCount)
if count < 2 || count > 5 {
t.Errorf("Unexpected cleanup count: %d (expected 2-5 for singleton)", count)
var count int32
waitLoop:
for {
select {
case <-ticker.C:
count = atomic.LoadInt32(&cleanupCount)
if count >= minExpectedCount {
// Success: reached minimum threshold
break waitLoop
}
case <-timeout:
count = atomic.LoadInt32(&cleanupCount)
t.Errorf("Timeout waiting for cleanup count to reach %d, got %d (race detector may be slowing execution)", minExpectedCount, count)
break waitLoop
}
}
// Verify count is within expected range (should be singleton, not running excessively)
if count > maxExpectedCount {
t.Errorf("Cleanup count too high: %d (expected max %d for singleton)", count, maxExpectedCount)
}
// Cleanup
+1
View File
@@ -244,6 +244,7 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
next: nextHandler,
issuerURL: testIssuerURL,
clientID: config.ClientID,
audience: config.ClientID,
clientSecret: config.ClientSecret,
redirURLPath: callbackPath,
logoutURLPath: logoutPath,
+39 -10
View File
@@ -171,13 +171,18 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.safeLogDebugf("Verifying JWT signature and claims")
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
// Read jwksURL with RLock
t.metadataMu.RLock()
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if !t.suppressDiagnosticLogs && jwks != nil {
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), t.jwksURL)
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL)
}
kid, ok := jwt.Header["kid"].(string)
@@ -235,7 +240,13 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
t.safeLogDebugf("DIAGNOSTIC: Signature verification successful for kid=%s", kid)
}
if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil {
// Use configured audience (defaults to clientID if not specified)
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
if err := jwt.Verify(issuerURL, t.audience, true); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
@@ -423,10 +434,15 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Returns:
// - An error if the request fails or the provider returns a non-OK status.
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
if t.revocationURL == "" {
// Read revocationURL with RLock
t.metadataMu.RLock()
revocationURL := t.revocationURL
t.metadataMu.RUnlock()
if revocationURL == "" {
return fmt.Errorf("token revocation endpoint is not configured or discovered")
}
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, revocationURL)
data := url.Values{
"token": {token},
@@ -435,7 +451,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
"client_secret": {t.clientSecret},
}
req, err := http.NewRequestWithContext(context.Background(), "POST", t.revocationURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(context.Background(), "POST", revocationURL, strings.NewReader(data.Encode()))
if err != nil {
return fmt.Errorf("failed to create token revocation request: %w", err)
}
@@ -446,7 +462,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
// Send the request with circuit breaker protection if available
var resp *http.Response
if t.errorRecoveryManager != nil {
// Read issuerURL with RLock for service name
t.metadataMu.RLock()
serviceName := fmt.Sprintf("token-revocation-%s", t.issuerURL)
t.metadataMu.RUnlock()
err = t.errorRecoveryManager.ExecuteWithRecovery(context.Background(), serviceName, func() error {
var reqErr error
resp, reqErr = t.httpClient.Do(req)
@@ -517,7 +536,12 @@ func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenRe
// Returns:
// - true if the provider is Google, false otherwise.
func (t *TraefikOidc) isGoogleProvider() bool {
return strings.Contains(t.issuerURL, "google") || strings.Contains(t.issuerURL, "accounts.google.com")
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
return strings.Contains(issuerURL, "google") || strings.Contains(issuerURL, "accounts.google.com")
}
// isAzureProvider detects if the configured OIDC provider is Azure AD.
@@ -525,9 +549,14 @@ func (t *TraefikOidc) isGoogleProvider() bool {
// Returns:
// - true if the provider is Azure AD, false otherwise.
func (t *TraefikOidc) isAzureProvider() bool {
return strings.Contains(t.issuerURL, "login.microsoftonline.com") ||
strings.Contains(t.issuerURL, "sts.windows.net") ||
strings.Contains(t.issuerURL, "login.windows.net")
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
return strings.Contains(issuerURL, "login.microsoftonline.com") ||
strings.Contains(issuerURL, "sts.windows.net") ||
strings.Contains(issuerURL, "login.windows.net")
}
// ============================================================================
+11 -6
View File
@@ -49,12 +49,13 @@ type TokenExchanger interface {
// This data is typically retrieved from the provider's .well-known/openid-configuration endpoint
// and contains essential URLs for authentication, token exchange, and key retrieval.
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
ScopesSupported []string `json:"scopes_supported,omitempty"` // NEW FIELD
}
// TraefikOidc is the main middleware struct that implements OIDC authentication for Traefik.
@@ -92,9 +93,11 @@ type TraefikOidc struct {
goroutineWG *sync.WaitGroup
clientSecret string
clientID string
audience string // Expected JWT audience, defaults to clientID
name string
redirURLPath string
logoutURLPath string
metadataMu sync.RWMutex // Protects metadata endpoint fields
tokenURL string
authURL string
endSessionURL string
@@ -115,4 +118,6 @@ type TraefikOidc struct {
firstRequestReceived bool
metadataRefreshStarted bool
securityHeadersApplier func(http.ResponseWriter, *http.Request)
scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering
scopesSupported []string // NEW - from provider metadata
}
+45 -3
View File
@@ -98,7 +98,28 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
scopes := make([]string, len(t.scopes))
copy(scopes, t.scopes)
// Apply discovery-based scope filtering if available
// Read scopesSupported with RLock
t.metadataMu.RLock()
scopesSupported := t.scopesSupported
t.metadataMu.RUnlock()
if t.scopeFilter != nil && len(scopesSupported) > 0 {
scopes = t.scopeFilter.FilterSupportedScopes(scopes, scopesSupported, t.providerURL)
t.logger.Debugf("TraefikOidc.buildAuthURL: After discovery filtering: %v", scopes)
}
// Then apply provider-specific modifications
if t.isGoogleProvider() {
// Google: Remove offline_access if present, add access_type=offline
filteredScopes := make([]string, 0, len(scopes))
for _, scope := range scopes {
if scope != "offline_access" {
filteredScopes = append(filteredScopes, scope)
}
}
scopes = filteredScopes
params.Set("access_type", "offline")
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
@@ -143,13 +164,29 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
}
}
// Final filtering pass to remove anything the provider doesn't support
// Read scopesSupported with RLock
t.metadataMu.RLock()
scopesSupported = t.scopesSupported
t.metadataMu.RUnlock()
if t.scopeFilter != nil && len(scopesSupported) > 0 {
scopes = t.scopeFilter.FilterSupportedScopes(scopes, scopesSupported, t.providerURL)
t.logger.Debugf("TraefikOidc.buildAuthURL: After final filtering: %v", scopes)
}
if len(scopes) > 0 {
finalScopeString := strings.Join(scopes, " ")
params.Set("scope", finalScopeString)
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
}
return t.buildURLWithParams(t.authURL, params)
// Read authURL with RLock
t.metadataMu.RLock()
authURL := t.authURL
t.metadataMu.RUnlock()
return t.buildURLWithParams(authURL, params)
}
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
@@ -172,9 +209,14 @@ func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) stri
}
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
issuerURLParsed, err := url.Parse(t.issuerURL)
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
issuerURLParsed, err := url.Parse(issuerURL)
if err != nil {
t.logger.Errorf("Could not parse issuerURL: %s. Error: %v", t.issuerURL, err)
t.logger.Errorf("Could not parse issuerURL: %s. Error: %v", issuerURL, err)
return ""
}