diff --git a/.traefik.yml b/.traefik.yml index 197c594..c64a221 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -77,6 +77,7 @@ testData: # Custom claim names for Auth0 and other providers with namespaced claims roleClaimName: roles # JWT claim name for extracting user roles (default: "roles") groupClaimName: groups # JWT claim name for extracting user groups (default: "groups") + userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username") # ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.) # When NOT specified in config: defaults to FALSE (Go zero value) @@ -120,6 +121,8 @@ testData: allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint) disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false) + allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false) + minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false) # Security Headers Configuration (enabled by default with 'default' profile) securityHeaders: @@ -266,6 +269,8 @@ testDataWithRedis: # allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers # - admin # - editor +# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal): +# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x # # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token. # # See README.md "Provider Configuration Recommendations" for Keycloak. @@ -287,6 +292,26 @@ testDataWithRedis: # - "AppRoleName" # # See README.md "Provider Configuration Recommendations" for Azure AD. +# --- Azure AD Users Without Email Example (Issue #95) --- +# testDataAzureADNoEmail: +# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 +# clientID: your-azure-ad-client-id +# clientSecret: your-azure-ad-client-secret +# callbackURL: /oauth2/callback +# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure" +# # Use 'sub' claim instead of 'email' for user identification +# userIdentifierClaim: sub # or "oid", "upn", "preferred_username" +# overrideScopes: true # Remove email scope if not needed +# scopes: +# - openid +# - profile +# - groups # For group-based access control +# # When using non-email identifiers, allowedUsers matches against the claim value +# allowedUsers: +# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim) +# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email" +# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95 + # --- Google Workspace / Google Cloud Identity Example --- # testDataGoogle: # providerURL: https://accounts.google.com # Standard Google OIDC endpoint @@ -605,6 +630,38 @@ configuration: items: type: string + userIdentifierClaim: + type: string + description: | + Specifies the JWT claim to use as the user identifier for authentication and authorization. + + This allows authentication for users without email addresses, such as Azure AD service + accounts or organizational accounts that don't have email attributes configured. + + When set to a non-email claim (e.g., "sub", "oid", "upn"): + - AllowedUsers will match against this claim value instead of email + - AllowedUserDomains validation is skipped (domains only apply to email addresses) + - The session stores this identifier as the user's identity + - If the configured claim is missing, falls back to "sub" (required by OIDC spec) + + Common values by provider: + - Default: "email" (standard email-based identification) + - Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username" + - Generic OIDC: "sub" (always present per OIDC specification) + - Keycloak: "sub", "preferred_username" + + Example for Azure AD users without email: + ```yaml + userIdentifierClaim: sub + allowedUsers: + - "abc123-user-object-id" + - "xyz789-another-user-id" + ``` + + Default: "email" + See: https://github.com/lukaszraczylo/traefikoidc/issues/95 + required: false + revocationURL: type: string description: | @@ -903,6 +960,67 @@ configuration: Default: false (replay detection enabled) required: false + allowPrivateIPAddresses: + type: boolean + description: | + Allow private IP addresses in OIDC provider URLs for internal network deployments. + + By default, the plugin blocks URLs containing private IP address ranges + (10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure + OIDC providers are publicly accessible. + + Enable this option when: + - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs + - You don't have DNS resolution available for internal services + - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing + + When enabled, the plugin will accept provider URLs like: + - https://192.168.1.100:8443/auth/realms/your-realm + - https://10.0.0.50:8080/realms/master + - https://172.16.0.10/auth + + Security Warning: + Enabling this option reduces SSRF protection. Only use in trusted network + environments where the OIDC provider is known and controlled. Loopback + addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled. + + Default: false (private IPs are blocked for security) + See: https://github.com/lukaszraczylo/traefikoidc/issues/97 + required: false + + minimalHeaders: + type: boolean + description: | + Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors. + + When enabled, the middleware only forwards the X-Forwarded-User header and skips + the larger authentication headers that can cause downstream services to reject + requests due to header size limits (typically 8KB). + + Headers when disabled (default): + - X-Forwarded-User: User's email address (always set) + - X-Auth-Request-Redirect: Original request URI + - X-Auth-Request-User: User's email address + - X-Auth-Request-Token: Full ID token (can be very large with many claims) + - X-User-Groups: Comma-separated user groups (if configured) + - X-User-Roles: Comma-separated user roles (if configured) + + Headers when enabled: + - X-Forwarded-User: User's email address (always set) + - X-User-Groups: Comma-separated user groups (if configured, still forwarded) + - X-User-Roles: Comma-separated user roles (if configured, still forwarded) + - Custom templated headers (still processed) + + Use this option when: + - Downstream services return "431 Request Header Fields Too Large" errors + - Your ID tokens are large (many claims, long group lists) + - You don't need the full ID token forwarded to backend services + - You want to reduce request overhead + + Default: false (all headers forwarded for backward compatibility) + See: https://github.com/lukaszraczylo/traefikoidc/issues/64 + required: false + headers: type: array description: | diff --git a/README.md b/README.md index 2e97a2f..3103b82 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ The middleware supports the following configuration options: | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` | | `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` | | `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` | +| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` | | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` | | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` | @@ -138,6 +139,8 @@ The middleware supports the following configuration options: | `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section | | `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section | | `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` | +| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` | +| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` | | `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section | > **⚠️ IMPORTANT - TLS Termination at Load Balancer:** @@ -1241,6 +1244,45 @@ spec: - "AppRoleName" # Application role names ``` +### Azure AD Configuration (Users Without Email) + +For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes): + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-azure-no-email + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 + clientID: your-azure-ad-client-id + clientSecret: your-azure-ad-client-secret + sessionEncryptionKey: your-secure-encryption-key-min-32-chars + callbackURL: /oauth2/callback + logoutURL: /oauth2/logout + + # Use 'sub' instead of 'email' for user identification + userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username" + + overrideScopes: true # Optional: Don't request email scope if not needed + scopes: + - openid + - profile + - groups + + # When using non-email identifiers, allowedUsers matches against the claim value + allowedUsers: + - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID + - "def67890-1234-5678-90ab-cdef12345678" + + # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email" +``` + +> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead. + ### Auth0 Configuration ```yaml @@ -1327,8 +1369,12 @@ spec: - admin - editor # Ensure Keycloak client mappers add necessary claims to ID Token + # For internal Keycloak deployments with private IPs (e.g., Docker network): + # allowPrivateIPAddresses: true ``` +> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details. + ### AWS Cognito Configuration ```yaml @@ -1629,12 +1675,39 @@ headers: When a user is authenticated, the middleware sets the following headers for downstream services: -- `X-Forwarded-User`: The user's email address +- `X-Forwarded-User`: The user's email address (always set) - `X-User-Groups`: Comma-separated list of user groups (if available) - `X-User-Roles`: Comma-separated list of user roles (if available) - `X-Auth-Request-Redirect`: The original request URI - `X-Auth-Request-User`: The user's email address -- `X-Auth-Request-Token`: The user's access token +- `X-Auth-Request-Token`: The user's ID token (can be large) + +#### Minimal Headers Mode + +If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead: + +```yaml +http: + middlewares: + my-auth: + plugin: + traefikoidc: + minimalHeaders: true + # ... other config +``` + +When `minimalHeaders: true` is set: +- **Only forwards**: `X-Forwarded-User` +- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect` +- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured) +- **Still processes**: Custom templated headers + +This is particularly useful when: +- Your ID tokens are large (many claims, long group lists) +- Downstream services have limited header buffer sizes (default 8KB in many servers) +- You don't need the full token forwarded to backend services + +See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details. ### Security Headers @@ -1862,6 +1935,15 @@ logLevel: debug - No refresh tokens (re-authentication required on expiry) - Use only for GitHub API access, not user authentication +15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)): + - When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \" error + - This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing + - **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like: + - `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}` + - `${OIDC_ENCRYPTION_SECRET_SERVICE}` + - `${OIDC_ENCRYPTION_SECRET_BACKEND}` + - Any name that doesn't contain the literal substring "API" + ### Provider Warnings and Recommendations The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about: diff --git a/audience_validation_test.go b/audience_validation_test.go index 8e07184..7fe7aa7 100644 --- a/audience_validation_test.go +++ b/audience_validation_test.go @@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) { customAudience := "https://api.company.com" tOidc := &TraefikOidc{ - next: nextHandler, - name: "test", - redirURLPath: "/callback", - logoutURLPath: "/callback/logout", - issuerURL: "https://auth.company.com", - clientID: "test-client-id", - clientSecret: "test-client-secret", - audience: customAudience, // Set custom audience - jwkCache: mockJWKCache, - jwksURL: "https://auth.company.com/.well-known/jwks.json", - tokenBlacklist: tokenBlacklist, - tokenCache: tokenCache, - limiter: rate.NewLimiter(rate.Every(time.Second), 10), - logger: logger, - allowedUserDomains: map[string]struct{}{"company.com": {}}, - excludedURLs: map[string]struct{}{}, - httpClient: &http.Client{}, - initComplete: make(chan struct{}), - sessionManager: sm, - extractClaimsFunc: extractClaims, + next: nextHandler, + name: "test", + redirURLPath: "/callback", + logoutURLPath: "/callback/logout", + issuerURL: "https://auth.company.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + audience: customAudience, // Set custom audience + jwkCache: mockJWKCache, + jwksURL: "https://auth.company.com/.well-known/jwks.json", + tokenBlacklist: tokenBlacklist, + tokenCache: tokenCache, + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + logger: logger, + allowedUserDomains: map[string]struct{}{"company.com": {}}, + userIdentifierClaim: "email", // Required for user identification + excludedURLs: map[string]struct{}{}, + httpClient: &http.Client{}, + initComplete: make(chan struct{}), + sessionManager: sm, + extractClaimsFunc: extractClaims, } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc diff --git a/auth/auth_handler.go b/auth/auth_handler.go index 7ab5a7b..8e303e5 100644 --- a/auth/auth_handler.go +++ b/auth/auth_handler.go @@ -18,17 +18,18 @@ type ScopeFilter interface { // Handler provides core authentication functionality for OIDC flows type Handler struct { - logger Logger - enablePKCE bool - isGoogleProv func() bool - isAzureProv func() bool - clientID string - authURL string - issuerURL string - scopes []string - overrideScopes bool - scopeFilter ScopeFilter // NEW - scopesSupported []string // NEW - from provider metadata + logger Logger + enablePKCE bool + isGoogleProv func() bool + isAzureProv func() bool + clientID string + authURL string + issuerURL string + scopes []string + overrideScopes bool + scopeFilter ScopeFilter // NEW + scopesSupported []string // NEW - from provider metadata + allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) } // Logger interface for dependency injection @@ -40,19 +41,20 @@ type Logger interface { // NewAuthHandler creates a new Handler instance func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool, clientID, authURL, issuerURL string, scopes []string, overrideScopes bool, - scopeFilter ScopeFilter, scopesSupported []string) *Handler { + scopeFilter ScopeFilter, scopesSupported []string, allowPrivateIPAddresses bool) *Handler { return &Handler{ - logger: logger, - enablePKCE: enablePKCE, - isGoogleProv: isGoogleProv, - isAzureProv: isAzureProv, - clientID: clientID, - authURL: authURL, - issuerURL: issuerURL, - scopes: scopes, - overrideScopes: overrideScopes, - scopeFilter: scopeFilter, // NEW - scopesSupported: scopesSupported, // NEW + logger: logger, + enablePKCE: enablePKCE, + isGoogleProv: isGoogleProv, + isAzureProv: isAzureProv, + clientID: clientID, + authURL: authURL, + issuerURL: issuerURL, + scopes: scopes, + overrideScopes: overrideScopes, + scopeFilter: scopeFilter, + scopesSupported: scopesSupported, + allowPrivateIPAddresses: allowPrivateIPAddresses, } } @@ -347,6 +349,7 @@ func (h *Handler) validateParsedURL(u *url.URL) error { // validateHost validates a hostname for security and reachability. // It prevents access to private networks and localhost addresses. +// When allowPrivateIPAddresses is enabled, private IP checks are skipped. func (h *Handler) validateHost(host string) error { if host == "" { return fmt.Errorf("empty host") @@ -361,7 +364,7 @@ func (h *Handler) validateHost(host string) error { } } - // Check for localhost variations + // Check for localhost variations (always blocked, even with allowPrivateIPAddresses) localhostVariations := []string{ "localhost", "127.0.0.1", "::1", "0.0.0.0", } @@ -376,7 +379,8 @@ func (h *Handler) validateHost(host string) error { if ip.IsLoopback() { return fmt.Errorf("loopback IP not allowed: %s", host) } - if ip.IsPrivate() { + // Skip private IP check if allowPrivateIPAddresses is enabled + if !h.allowPrivateIPAddresses && ip.IsPrivate() { return fmt.Errorf("private IP not allowed: %s", host) } if ip.IsLinkLocalUnicast() { diff --git a/auth/auth_handler_test.go b/auth/auth_handler_test.go index 974df40..1451db6 100644 --- a/auth/auth_handler_test.go +++ b/auth/auth_handler_test.go @@ -86,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) { handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv, "test-client-id", "https://example.com/auth", "https://example.com", - scopes, false, nil, nil) + scopes, false, nil, nil, false) if handler == nil { t.Fatal("Expected handler to be created, got nil") @@ -125,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) { func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) session := &mockSessionData{redirectCount: 5} // At the limit req := httptest.NewRequest("GET", "/test", nil) @@ -160,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) { func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) session := &mockSessionData{} req := httptest.NewRequest("GET", "/test", nil) @@ -191,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) { func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) session := &mockSessionData{} req := httptest.NewRequest("GET", "/test", nil) @@ -222,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) session := &mockSessionData{} req := httptest.NewRequest("GET", "/test", nil) @@ -253,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) session := &mockSessionData{saveError: &testError{"save failed"}} req := httptest.NewRequest("GET", "/test?param=value", nil) @@ -297,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) { func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil, false) session := &mockSessionData{} req := httptest.NewRequest("GET", "/protected/resource", nil) @@ -400,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false }, "google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com", - []string{"openid", "profile", "email"}, false, nil, nil) + []string{"openid", "profile", "email"}, false, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -440,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true }, "azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize", "https://login.microsoftonline.com/tenant/v2.0", - []string{"openid", "profile", "email"}, false, nil, nil) + []string{"openid", "profile", "email"}, false, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -468,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, "pkce-client", "https://example.com/auth", "https://example.com", - []string{"openid"}, false, nil, nil) + []string{"openid"}, false, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge") @@ -493,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "no-pkce-client", "https://example.com/auth", "https://example.com", - []string{"openid"}, false, nil, nil) + []string{"openid"}, false, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge") @@ -565,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure }, "test-client", "https://example.com/auth", "https://example.com", - tt.scopes, tt.overrideScopes, nil, nil) + tt.scopes, tt.overrideScopes, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -634,7 +634,7 @@ func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -676,7 +676,7 @@ func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, nil, nil) + scopes, false, nil, nil, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -714,7 +714,7 @@ func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "gitlab-client", "https://gitlab.example.com/oauth/authorize", "https://gitlab.example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -756,7 +756,7 @@ func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false }, "google-client", "https://accounts.google.com/o/oauth2/v2/auth", "https://accounts.google.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -797,7 +797,7 @@ func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true }, "azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize", "https://login.microsoftonline.com/tenant/v2.0", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -831,7 +831,7 @@ func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "generic-client", "https://auth.provider.com/authorize", "https://auth.provider.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -870,7 +870,7 @@ func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, true, scopeFilter, scopesSupported) + scopes, true, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -916,7 +916,7 @@ func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -955,7 +955,7 @@ func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, nil, scopesSupported) // scopeFilter is nil + scopes, false, nil, scopesSupported, false) // scopeFilter is nil authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -988,7 +988,7 @@ func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") @@ -1021,7 +1021,7 @@ func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) { handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge") @@ -1064,7 +1064,7 @@ func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) { handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false }, "complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789") @@ -1130,7 +1130,7 @@ func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) { handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, "test-client", "https://example.com/auth", "https://example.com", - scopes, false, scopeFilter, scopesSupported) + scopes, false, scopeFilter, scopesSupported, false) handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "") diff --git a/auth/url_validation_test.go b/auth/url_validation_test.go index db1d93f..d326524 100644 --- a/auth/url_validation_test.go +++ b/auth/url_validation_test.go @@ -10,7 +10,7 @@ import ( func TestAuthHandler_validateURL(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) tests := []struct { name string @@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) { func TestAuthHandler_validateHost(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) tests := []struct { name string @@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) { func TestAuthHandler_buildURLWithParams(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) tests := []struct { name string @@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) { func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) // Test special characters that need encoding params := url.Values{ @@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) { func TestAuthHandler_validateParsedURL(t *testing.T) { logger := &mockLogger{} handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, - "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil) + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) tests := []struct { name string @@ -560,3 +560,101 @@ func TestAuthHandler_validateParsedURL(t *testing.T) { }) } } + +// TestAuthHandler_validateHost_AllowPrivateIPAddresses tests the allowPrivateIPAddresses flag +func TestAuthHandler_validateHost_AllowPrivateIPAddresses(t *testing.T) { + logger := &mockLogger{} + + // Test with allowPrivateIPAddresses = false (default) + t.Run("Private IPs blocked by default", func(t *testing.T) { + handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false) + + privateIPs := []string{ + "192.168.1.1", + "10.0.0.1", + "172.16.0.1", + "172.31.255.255", + } + + for _, ip := range privateIPs { + err := handler.validateHost(ip) + if err == nil { + t.Errorf("Expected private IP %s to be blocked, but it was allowed", ip) + } + if err != nil && !strings.Contains(err.Error(), "private IP not allowed") { + t.Errorf("Expected 'private IP not allowed' error for %s, got: %v", ip, err) + } + } + }) + + // Test with allowPrivateIPAddresses = true + t.Run("Private IPs allowed when flag enabled", func(t *testing.T) { + handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true) + + privateIPs := []string{ + "192.168.1.1", + "10.0.0.1", + "172.16.0.1", + "172.31.255.255", + } + + for _, ip := range privateIPs { + err := handler.validateHost(ip) + if err != nil { + t.Errorf("Expected private IP %s to be allowed with flag enabled, but got error: %v", ip, err) + } + } + }) + + // Test that loopback is still blocked even with flag enabled + t.Run("Loopback always blocked", func(t *testing.T) { + handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true) + + loopbackAddresses := []string{ + "127.0.0.1", + "localhost", + "::1", + "0.0.0.0", + } + + for _, addr := range loopbackAddresses { + err := handler.validateHost(addr) + if err == nil { + t.Errorf("Expected loopback address %s to be blocked even with allowPrivateIPAddresses=true", addr) + } + } + }) + + // Test that link-local is still blocked even with flag enabled + t.Run("Link-local always blocked", func(t *testing.T) { + handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true) + + err := handler.validateHost("169.254.1.1") + if err == nil { + t.Error("Expected link-local address to be blocked even with allowPrivateIPAddresses=true") + } + }) + + // Test that public IPs work with flag enabled + t.Run("Public IPs allowed", func(t *testing.T) { + handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false }, + "test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true) + + publicIPs := []string{ + "8.8.8.8", + "1.1.1.1", + "142.250.185.68", + } + + for _, ip := range publicIPs { + err := handler.validateHost(ip) + if err != nil { + t.Errorf("Expected public IP %s to be allowed, but got error: %v", ip, err) + } + } + }) +} diff --git a/auth_flow.go b/auth_flow.go index 5b0cb9f..505c532 100644 --- a/auth_flow.go +++ b/auth_flow.go @@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - email, _ := claims["email"].(string) - if email == "" { - t.logger.Errorf("Email claim missing or empty in token during callback") - t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError) - return + // Extract user identifier from the configured claim (defaults to "email" for backward compatibility) + userIdentifier, _ := claims[t.userIdentifierClaim].(string) + if userIdentifier == "" { + // Try "sub" as fallback since it's required by OIDC spec + if t.userIdentifierClaim != "sub" { + userIdentifier, _ = claims["sub"].(string) + } + if userIdentifier == "" { + t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim) + t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError) + return + } + t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim) } - if !t.isAllowedDomain(email) { - t.logger.Errorf("Disallowed email domain during callback: %s", email) - t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden) + + // Validate user authorization + if !t.isAllowedUser(userIdentifier) { + t.logger.Errorf("User not authorized during callback: %s", userIdentifier) + t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden) return } @@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError) return } - session.SetEmail(email) + session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim) session.SetIDToken(tokenResponse.IDToken) session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) diff --git a/docs/PROVIDER_CONFIGURATIONS.md b/docs/PROVIDER_CONFIGURATIONS.md index e00a9d3..b0db1b8 100644 --- a/docs/PROVIDER_CONFIGURATIONS.md +++ b/docs/PROVIDER_CONFIGURATIONS.md @@ -437,6 +437,21 @@ http: 4. Configure client scopes and mappers 5. Generate client secret in Credentials tab +### Internal Network Deployment + +If your Keycloak instance runs on an internal network with private IP addresses (e.g., Docker networks, Kubernetes internal services), set `allowPrivateIPAddresses: true`: + +```yaml +traefikoidc: + providerUrl: "https://192.168.1.100:8443/auth/realms/your-realm" # Private IP + allowPrivateIPAddresses: true # Required for private IP addresses + clientId: "your-client-id" + clientSecret: "your-client-secret" + # ... other config +``` + +> **Security Warning**: Only enable `allowPrivateIPAddresses` in trusted network environments where you control the OIDC provider. This setting reduces SSRF protection. + --- ## Okta diff --git a/error_recovery.go b/error_recovery.go index dcd08af..9300696 100644 --- a/error_recovery.go +++ b/error_recovery.go @@ -2,10 +2,14 @@ package traefikoidc import ( "context" + "crypto/x509" + "errors" "fmt" + "io" "math" "math/rand/v2" "net" + "strings" "sync" "sync/atomic" "time" @@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig { } } +// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata +// fetching during startup. Uses more aggressive retry settings to handle the race +// condition where Traefik initializes the plugin before routes are fully established, +// or before TLS certificates are properly loaded. +// See: https://github.com/lukaszraczylo/traefikoidc/issues/90 +func MetadataFetchRetryConfig() RetryConfig { + return RetryConfig{ + MaxAttempts: 10, // More attempts for startup scenarios + InitialDelay: 1 * time.Second, // 1 second between attempts as suggested + MaxDelay: 10 * time.Second, // Cap at 10 seconds + BackoffFactor: 1.5, // Gentler backoff for startup + EnableJitter: true, // Prevent thundering herd + RetryableErrors: []string{ + "connection refused", + "timeout", + "temporary failure", + "network unreachable", + "EOF", + "certificate", + "x509", + "tls", + }, + } +} + // RetryExecutor implements retry logic with exponential backoff and jitter. // It automatically retries failed operations based on configurable error patterns // and uses exponential backoff to avoid overwhelming failing services. @@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error { // isRetryableError checks if an error should trigger a retry // isRetryableError determines if an error should trigger a retry attempt. // Checks error message against configured retryable error patterns. +// Also handles startup-specific errors like Traefik default certificate errors +// and EOF errors that occur during service initialization. func (re *RetryExecutor) isRetryableError(err error) bool { if err == nil { return false } + // Check for Traefik default certificate error (startup race condition) + // See: https://github.com/lukaszraczylo/traefikoidc/issues/90 + if isTraefikDefaultCertError(err) { + return true + } + + // Check for EOF errors (common during startup when services aren't ready) + if isEOFError(err) { + return true + } + + // Check for certificate errors (transient during startup) + if isCertificateError(err) { + return true + } + errStr := err.Error() for _, retryableErr := range re.config.RetryableErrors { @@ -1088,3 +1135,86 @@ func containsSubstring(s, substr string) bool { } return false } + +// isTraefikDefaultCertError detects when Traefik is serving its default self-signed +// certificate during cold-start, before the real certificates are loaded. +// This manifests as an x509.HostnameError where one of the certificate's DNS names +// ends with "traefik.default" (the default Traefik certificate pattern). +// See: https://github.com/lukaszraczylo/traefikoidc/issues/90 +func isTraefikDefaultCertError(err error) bool { + if err == nil { + return false + } + + var hostnameErr x509.HostnameError + if errors.As(err, &hostnameErr) { + if hostnameErr.Certificate != nil { + for _, name := range hostnameErr.Certificate.DNSNames { + if strings.HasSuffix(name, "traefik.default") { + return true + } + } + } + } + + return false +} + +// isEOFError checks if an error is an EOF error, which can occur during +// connection establishment when the remote end closes unexpectedly. +// This is common during service startup when endpoints aren't fully ready. +func isEOFError(err error) bool { + if err == nil { + return false + } + + // Check for direct EOF + if errors.Is(err, io.EOF) { + return true + } + + // Check for unexpected EOF + if errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + // Check error message for EOF patterns (wrapped errors) + errStr := err.Error() + return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF") +} + +// isCertificateError checks if an error is related to TLS certificate validation. +// These errors are often transient during startup when services are still initializing. +func isCertificateError(err error) bool { + if err == nil { + return false + } + + // Check for x509 certificate errors + var certInvalidErr x509.CertificateInvalidError + var hostnameErr x509.HostnameError + var unknownAuthErr x509.UnknownAuthorityError + + if errors.As(err, &certInvalidErr) || + errors.As(err, &hostnameErr) || + errors.As(err, &unknownAuthErr) { + return true + } + + // Check error message for certificate patterns + errStr := strings.ToLower(err.Error()) + certPatterns := []string{ + "certificate", + "x509", + "tls", + "ssl", + } + + for _, pattern := range certPatterns { + if strings.Contains(errStr, pattern) { + return true + } + } + + return false +} diff --git a/error_recovery_test.go b/error_recovery_test.go index 17f2c92..2687376 100644 --- a/error_recovery_test.go +++ b/error_recovery_test.go @@ -846,3 +846,296 @@ func (e *mockNetError) Temporary() bool { return e.temporary } // Ensure mockNetError implements net.Error var _ net.Error = (*mockNetError)(nil) + +// Test isTraefikDefaultCertError +// See: https://github.com/lukaszraczylo/traefikoidc/issues/90 + +func TestIsTraefikDefaultCertError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "regular error", + err: errors.New("some error"), + expected: false, + }, + { + name: "network error", + err: &mockNetError{msg: "connection refused"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isTraefikDefaultCertError(tt.err) + if result != tt.expected { + t.Errorf("isTraefikDefaultCertError() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// Test isEOFError + +func TestIsEOFError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "regular error", + err: errors.New("some error"), + expected: false, + }, + { + name: "error containing EOF in message", + err: errors.New("connection closed: EOF"), + expected: true, + }, + { + name: "error containing unexpected EOF", + err: errors.New("read: unexpected EOF"), + expected: true, + }, + { + name: "network error without EOF", + err: &mockNetError{msg: "connection refused"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isEOFError(tt.err) + if result != tt.expected { + t.Errorf("isEOFError() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// Test isCertificateError + +func TestIsCertificateError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "regular error", + err: errors.New("some error"), + expected: false, + }, + { + name: "error containing certificate in message", + err: errors.New("tls: failed to verify certificate"), + expected: true, + }, + { + name: "error containing x509 in message", + err: errors.New("x509: certificate signed by unknown authority"), + expected: true, + }, + { + name: "error containing tls in message", + err: errors.New("tls handshake failed"), + expected: true, + }, + { + name: "error containing ssl in message", + err: errors.New("ssl connection error"), + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isCertificateError(tt.err) + if result != tt.expected { + t.Errorf("isCertificateError() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// Test MetadataFetchRetryConfig + +func TestMetadataFetchRetryConfig(t *testing.T) { + config := MetadataFetchRetryConfig() + + if config.MaxAttempts != 10 { + t.Errorf("Expected MaxAttempts 10, got %d", config.MaxAttempts) + } + + if config.InitialDelay != 1*time.Second { + t.Errorf("Expected InitialDelay 1s, got %v", config.InitialDelay) + } + + if config.MaxDelay != 10*time.Second { + t.Errorf("Expected MaxDelay 10s, got %v", config.MaxDelay) + } + + if config.BackoffFactor != 1.5 { + t.Errorf("Expected BackoffFactor 1.5, got %v", config.BackoffFactor) + } + + if !config.EnableJitter { + t.Error("Expected EnableJitter to be true") + } + + // Verify retryable errors include startup-related patterns + expectedPatterns := []string{"EOF", "certificate", "x509", "tls"} + for _, pattern := range expectedPatterns { + found := false + for _, retryableErr := range config.RetryableErrors { + if retryableErr == pattern { + found = true + break + } + } + if !found { + t.Errorf("Expected '%s' in RetryableErrors", pattern) + } + } +} + +// Test RetryExecutor with startup-specific errors + +func TestRetryExecutorStartupErrors(t *testing.T) { + // Verify MetadataFetchRetryConfig creates a valid retry executor + _ = NewRetryExecutor(MetadataFetchRetryConfig(), nil) + + tests := []struct { + name string + err error + shouldRetry bool + }{ + { + name: "EOF error", + err: errors.New("read tcp: EOF"), + shouldRetry: true, + }, + { + name: "unexpected EOF", + err: errors.New("http: unexpected EOF"), + shouldRetry: true, + }, + { + name: "certificate error", + err: errors.New("x509: certificate signed by unknown authority"), + shouldRetry: true, + }, + { + name: "TLS error", + err: errors.New("tls: failed to verify certificate"), + shouldRetry: true, + }, + { + name: "connection refused", + err: errors.New("dial tcp: connection refused"), + shouldRetry: true, + }, + { + name: "permanent error", + err: errors.New("invalid response format"), + shouldRetry: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use very short delays for testing + testConfig := RetryConfig{ + MaxAttempts: 3, + InitialDelay: 1 * time.Millisecond, + MaxDelay: 10 * time.Millisecond, + BackoffFactor: 1.5, + EnableJitter: false, + RetryableErrors: []string{ + "connection refused", + "timeout", + "temporary failure", + "network unreachable", + "EOF", + "certificate", + "x509", + "tls", + }, + } + testRe := NewRetryExecutor(testConfig, nil) + + attempts := 0 + _ = testRe.ExecuteWithContext(context.Background(), func() error { + attempts++ + return tt.err + }) + + expectedAttempts := 1 + if tt.shouldRetry { + expectedAttempts = 3 + } + + if attempts != expectedAttempts { + t.Errorf("Expected %d attempts for '%s', got %d", expectedAttempts, tt.name, attempts) + } + }) + } +} + +// Test that retry executor properly uses isRetryableError with new error types + +func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) { + re := NewRetryExecutor(DefaultRetryConfig(), nil) + + // Test that the enhanced isRetryableError is being used + tests := []struct { + name string + err error + shouldRetry bool + }{ + { + name: "EOF in error message", + err: errors.New("connection reset by peer: EOF"), + shouldRetry: true, + }, + { + name: "certificate in error message", + err: errors.New("x509: certificate has expired"), + shouldRetry: true, + }, + { + name: "TLS in error message", + err: errors.New("tls: handshake failure"), + shouldRetry: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := re.isRetryableError(tt.err) + if result != tt.shouldRetry { + t.Errorf("isRetryableError(%q) = %v, expected %v", tt.err.Error(), result, tt.shouldRetry) + } + }) + } +} diff --git a/handlers/oauth_handler.go b/handlers/oauth_handler.go index 055d4f6..2a1f1d3 100644 --- a/handlers/oauth_handler.go +++ b/handlers/oauth_handler.go @@ -15,7 +15,8 @@ type OAuthHandler struct { tokenExchanger TokenExchanger tokenVerifier TokenVerifier extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - isAllowedDomainFunc func(email string) bool + isAllowedUserFunc func(userIdentifier string) bool // validates user authorization + userIdentifierClaim string // JWT claim to use for user identification redirURLPath string sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int) } @@ -77,16 +78,22 @@ type TokenResponse struct { // NewOAuthHandler creates a new OAuth handler func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger, tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error), - isAllowedDomainFunc func(string) bool, redirURLPath string, + isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string, sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler { + // Default to "email" for backward compatibility + if userIdentifierClaim == "" { + userIdentifierClaim = "email" + } + return &OAuthHandler{ logger: logger, sessionManager: sessionManager, tokenExchanger: tokenExchanger, tokenVerifier: tokenVerifier, extractClaimsFunc: extractClaimsFunc, - isAllowedDomainFunc: isAllowedDomainFunc, + isAllowedUserFunc: isAllowedUserFunc, + userIdentifierClaim: userIdentifierClaim, redirURLPath: redirURLPath, sendErrorResponseFunc: sendErrorResponseFunc, } @@ -225,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, return } - email, _ := claims["email"].(string) - if email == "" { - h.logger.Errorf("Email claim missing or empty in token during callback") - h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError) - return + // Extract user identifier from the configured claim (defaults to "email" for backward compatibility) + userIdentifier, _ := claims[h.userIdentifierClaim].(string) + if userIdentifier == "" { + // Try "sub" as fallback since it's required by OIDC spec + if h.userIdentifierClaim != "sub" { + userIdentifier, _ = claims["sub"].(string) + } + if userIdentifier == "" { + h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim) + h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError) + return + } + h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim) } - if !h.isAllowedDomainFunc(email) { - h.logger.Errorf("Disallowed email domain during callback: %s", email) - h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden) + + // Validate user authorization + if !h.isAllowedUserFunc(userIdentifier) { + h.logger.Errorf("User not authorized during callback: %s", userIdentifier) + h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden) return } @@ -242,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError) return } - session.SetEmail(email) + session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim) session.SetIDToken(tokenResponse.IDToken) session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) diff --git a/handlers/oauth_handler_test.go b/handlers/oauth_handler_test.go index 2e3c9f0..615ce55 100644 --- a/handlers/oauth_handler_test.go +++ b/handlers/oauth_handler_test.go @@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) { return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil } - isAllowed := func(email string) bool { return true } + isAllowedUser := func(userIdentifier string) bool { return true } sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {} handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowedUser, "email", "/callback", sendError) if handler == nil { t.Fatal("Expected handler to be created, got nil") @@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil) rw := httptest.NewRecorder() @@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) // Test with error parameter req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil) @@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test", nil) rw := httptest.NewRecorder() @@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil) rw := httptest.NewRecorder() @@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil) rw := httptest.NewRecorder() @@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?state=test-state", nil) rw := httptest.NewRecorder() @@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) { if code != http.StatusInternalServerError { t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code) } - if !strings.Contains(msg, "Email missing in token") { - t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg) + if !strings.Contains(msg, "User identifier missing in token") { + t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg) } } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) { if code != http.StatusForbidden { t.Errorf("Expected status %d, got %d", http.StatusForbidden, code) } - if !strings.Contains(msg, "Email domain not allowed") { - t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg) + if !strings.Contains(msg, "User not authorized") { + t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg) } } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() diff --git a/input_validation.go b/input_validation.go index 723cc05..d59eb40 100644 --- a/input_validation.go +++ b/input_validation.go @@ -15,20 +15,21 @@ import ( // XSS, path traversal, and other injection attacks. It validates and sanitizes // various input types used in OIDC authentication flows. type InputValidator struct { - usernameRegex *regexp.Regexp - tokenRegex *regexp.Regexp - logger *Logger - urlRegex *regexp.Regexp - emailRegex *regexp.Regexp - sqlInjectionPatterns []string - pathTraversalPatterns []string - xssPatterns []string - maxUsernameLength int - maxURLLength int - maxTokenLength int - maxEmailLength int - maxClaimLength int - maxHeaderLength int + usernameRegex *regexp.Regexp + tokenRegex *regexp.Regexp + logger *Logger + urlRegex *regexp.Regexp + emailRegex *regexp.Regexp + sqlInjectionPatterns []string + pathTraversalPatterns []string + xssPatterns []string + maxUsernameLength int + maxURLLength int + maxTokenLength int + maxEmailLength int + maxClaimLength int + maxHeaderLength int + allowPrivateIPAddresses bool // Allow private IP addresses in URL validation } // ValidationResult encapsulates the outcome of input validation. @@ -46,13 +47,14 @@ type ValidationResult struct { // It specifies maximum lengths for various input types and controls whether // strict validation mode is enabled. type InputValidationConfig struct { - MaxTokenLength int `json:"max_token_length"` - MaxURLLength int `json:"max_url_length"` - MaxHeaderLength int `json:"max_header_length"` - MaxClaimLength int `json:"max_claim_length"` - MaxEmailLength int `json:"max_email_length"` - MaxUsernameLength int `json:"max_username_length"` - StrictMode bool `json:"strict_mode"` + MaxTokenLength int `json:"max_token_length"` + MaxURLLength int `json:"max_url_length"` + MaxHeaderLength int `json:"max_header_length"` + MaxClaimLength int `json:"max_claim_length"` + MaxEmailLength int `json:"max_email_length"` + MaxUsernameLength int `json:"max_username_length"` + StrictMode bool `json:"strict_mode"` + AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation } // DefaultInputValidationConfig returns a secure default configuration @@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali } return &InputValidator{ - maxTokenLength: config.MaxTokenLength, - maxURLLength: config.MaxURLLength, - maxHeaderLength: config.MaxHeaderLength, - maxClaimLength: config.MaxClaimLength, - maxEmailLength: config.MaxEmailLength, - maxUsernameLength: config.MaxUsernameLength, - emailRegex: emailRegex, - urlRegex: urlRegex, - tokenRegex: tokenRegex, - usernameRegex: usernameRegex, + maxTokenLength: config.MaxTokenLength, + maxURLLength: config.MaxURLLength, + maxHeaderLength: config.MaxHeaderLength, + maxClaimLength: config.MaxClaimLength, + maxEmailLength: config.MaxEmailLength, + maxUsernameLength: config.MaxUsernameLength, + allowPrivateIPAddresses: config.AllowPrivateIPAddresses, + emailRegex: emailRegex, + urlRegex: urlRegex, + tokenRegex: tokenRegex, + usernameRegex: usernameRegex, sqlInjectionPatterns: []string{ "'", "\"", ";", "--", "/*", "*/", "xp_", "sp_", "union", "select", "insert", "update", "delete", "drop", @@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult { } } - // Check for private IP ranges (RFC 1918) - if strings.HasPrefix(hostname, "10.") || - strings.HasPrefix(hostname, "192.168.") || - strings.HasPrefix(hostname, "172.") { - // For 172.x check if it's in the 172.16.0.0/12 range - if strings.HasPrefix(hostname, "172.") { - parts := strings.Split(hostname, ".") - if len(parts) >= 2 { - if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 { - result.IsValid = false - result.Errors = append(result.Errors, "private IP URLs are not allowed for security") - return result + // Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled + if !iv.allowPrivateIPAddresses { + if strings.HasPrefix(hostname, "10.") || + strings.HasPrefix(hostname, "192.168.") || + strings.HasPrefix(hostname, "172.") { + // For 172.x check if it's in the 172.16.0.0/12 range + if strings.HasPrefix(hostname, "172.") { + parts := strings.Split(hostname, ".") + if len(parts) >= 2 { + if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 { + result.IsValid = false + result.Errors = append(result.Errors, "private IP URLs are not allowed for security") + return result + } } + } else { + result.IsValid = false + result.Errors = append(result.Errors, "private IP URLs are not allowed for security") + return result } - } else { - result.IsValid = false - result.Errors = append(result.Errors, "private IP URLs are not allowed for security") - return result } } diff --git a/main.go b/main.go index 398c8a5..e61a2f2 100644 --- a/main.go +++ b/main.go @@ -177,6 +177,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name } return "groups" // Backward compatible default }(), + userIdentifierClaim: func() string { + if config.UserIdentifierClaim != "" { + return config.UserIdentifierClaim + } + return "email" // Backward compatible default + }(), forceHTTPS: config.ForceHTTPS, enablePKCE: config.EnablePKCE, overrideScopes: config.OverrideScopes, @@ -218,6 +224,8 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name securityHeadersApplier: config.GetSecurityHeadersApplier(), scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering dcrConfig: config.DynamicClientRegistration, + allowPrivateIPAddresses: config.AllowPrivateIPAddresses, + minimalHeaders: config.MinimalHeaders, } // Log audience configuration diff --git a/main_servehttp_test.go b/main_servehttp_test.go index bb7a82f..c10c349 100644 --- a/main_servehttp_test.go +++ b/main_servehttp_test.go @@ -545,3 +545,168 @@ func createTestSessionManager(t *testing.T) *SessionManager { } return sm } + +// TestMinimalHeaders tests the minimalHeaders configuration option +// This addresses GitHub issue #64 - Request Header Fields Too Large +func TestMinimalHeaders(t *testing.T) { + tests := []struct { + name string + minimalHeaders bool + expectForwardedUser bool + expectAuthRequestUser bool + expectAuthRequestRedirect bool + }{ + { + name: "minimalHeaders=false (default) forwards all headers", + minimalHeaders: false, + expectForwardedUser: true, + expectAuthRequestUser: true, + expectAuthRequestRedirect: true, + }, + { + name: "minimalHeaders=true only forwards X-Forwarded-User", + minimalHeaders: true, + expectForwardedUser: true, + expectAuthRequestUser: false, + expectAuthRequestRedirect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Track which headers were set + var capturedHeaders http.Header + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + }) + + sessionManager := createTestSessionManager(t) + oidc := &TraefikOidc{ + next: next, + logger: NewLogger("debug"), + initComplete: make(chan struct{}), + sessionManager: sessionManager, + firstRequestReceived: true, + metadataRefreshStarted: true, + issuerURL: "https://provider.example.com", + minimalHeaders: tt.minimalHeaders, + extractClaimsFunc: func(token string) (map[string]interface{}, error) { + return map[string]interface{}{ + "email": "user@example.com", + }, nil + }, + } + close(oidc.initComplete) + + // Create request and get session properly through session manager + req := httptest.NewRequest("GET", "/protected", nil) + rw := httptest.NewRecorder() + + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + // Set up session data + session.SetEmail("user@example.com") + session.SetAuthenticated(true) + + // Call processAuthorizedRequest directly + oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback") + + // Verify X-Forwarded-User is always set + if tt.expectForwardedUser { + if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" { + t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User")) + } + } + + // Verify X-Auth-Request-User + hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != "" + if tt.expectAuthRequestUser && !hasAuthRequestUser { + t.Error("expected X-Auth-Request-User to be set") + } + if !tt.expectAuthRequestUser && hasAuthRequestUser { + t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User")) + } + + // Verify X-Auth-Request-Redirect + hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != "" + if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect { + t.Error("expected X-Auth-Request-Redirect to be set") + } + if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect { + t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect")) + } + + // Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty. + // Token storage has validation that may reject test tokens, so we verify the flag + // logic through the other headers. The important behavior is that when + // minimalHeaders=true, the token header would NOT be set even if a token existed. + }) + } +} + +// TestMinimalHeaders_TokenHeaderNotSet verifies that the X-Auth-Request-Token header +// is NOT set when minimalHeaders is enabled, even if a token exists. +func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) { + var capturedHeaders http.Header + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + }) + + sessionManager := createTestSessionManager(t) + oidc := &TraefikOidc{ + next: next, + logger: NewLogger("debug"), + initComplete: make(chan struct{}), + sessionManager: sessionManager, + firstRequestReceived: true, + metadataRefreshStarted: true, + issuerURL: "https://provider.example.com", + minimalHeaders: true, // Enable minimal headers + extractClaimsFunc: func(token string) (map[string]interface{}, error) { + return map[string]interface{}{ + "email": "user@example.com", + }, nil + }, + } + close(oidc.initComplete) + + req := httptest.NewRequest("GET", "/protected", nil) + rw := httptest.NewRecorder() + + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + session.SetEmail("user@example.com") + session.SetAuthenticated(true) + + oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback") + + // Verify X-Forwarded-User is set (always should be) + if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" { + t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User")) + } + + // The key verification: X-Auth-Request-Token should NOT be set with minimalHeaders=true + if capturedHeaders.Get("X-Auth-Request-Token") != "" { + t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true") + } + + // X-Auth-Request-User should also NOT be set with minimalHeaders=true + if capturedHeaders.Get("X-Auth-Request-User") != "" { + t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true") + } + + // X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true + if capturedHeaders.Get("X-Auth-Request-Redirect") != "" { + t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true") + } +} diff --git a/main_test.go b/main_test.go index 3da5b2e..57a2d9e 100644 --- a/main_test.go +++ b/main_test.go @@ -122,22 +122,23 @@ func (ts *TestSuite) Setup() { // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ - issuerURL: "https://test-issuer.com", - clientID: "test-client-id", - audience: "test-client-id", - clientSecret: "test-client-secret", - roleClaimName: "roles", // Set default for backward compatibility - groupClaimName: "groups", // Set default for backward compatibility - jwkCache: ts.mockJWKCache, - jwksURL: "https://test-jwks-url.com", - revocationURL: "https://revocation-endpoint.com", - limiter: rate.NewLimiter(rate.Every(time.Second), 10), - tokenBlacklist: tokenBlacklist, - tokenCache: tokenCache, - logger: logger, - allowedUserDomains: map[string]struct{}{"example.com": {}}, - excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}}, - httpClient: &http.Client{Timeout: 10 * time.Second}, + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + audience: "test-client-id", + clientSecret: "test-client-secret", + roleClaimName: "roles", // Set default for backward compatibility + groupClaimName: "groups", // Set default for backward compatibility + userIdentifierClaim: "email", // Set default for backward compatibility + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + revocationURL: "https://revocation-endpoint.com", + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + tokenBlacklist: tokenBlacklist, + tokenCache: tokenCache, + logger: logger, + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}}, + httpClient: &http.Client{Timeout: 10 * time.Second}, // Explicitly set paths as New() is bypassed redirURLPath: "/callback", // Assume default callback path for tests logoutURLPath: "/callback/logout", // Assume default logout path for tests @@ -784,7 +785,7 @@ func TestServeHTTP(t *testing.T) { "Accept": "application/json", }, expectedStatus: http.StatusForbidden, - expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`, + expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`, }, { name: "Disallowed Domain (Accept: HTML)", @@ -1282,8 +1283,9 @@ func TestHandleCallback(t *testing.T) { instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case } tOidc := &TraefikOidc{ - allowedUserDomains: map[string]struct{}{"example.com": {}}, - logger: logger, + allowedUserDomains: map[string]struct{}{"example.com": {}}, + logger: logger, + userIdentifierClaim: "email", // Required for claim extraction // exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function tokenVerifier: nil, // Will be set to self below @@ -1438,6 +1440,228 @@ func TestIsAllowedDomain(t *testing.T) { } } +func TestIsAllowedUser(t *testing.T) { + ts := NewTestSuite(t) + ts.Setup() + + tests := []struct { + allowedDomains map[string]struct{} + allowedUsers map[string]struct{} + userIdentifierClaim string + name string + userIdentifier string + allowed bool + }{ + // Email-based identification (default behavior) + { + name: "Email identifier - allowed domain", + userIdentifier: "user@example.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{}, + allowed: true, + }, + { + name: "Email identifier - disallowed domain", + userIdentifier: "user@notallowed.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{}, + allowed: false, + }, + { + name: "Email identifier - specific user allowed", + userIdentifier: "specific.user@otherdomain.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, + allowed: true, + }, + + // Non-email identifier (sub claim - for Azure AD users without email) + { + name: "Sub identifier - allowed in allowedUsers", + userIdentifier: "abc12345-6789-0abc-def0-123456789abc", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, + allowed: true, + }, + { + name: "Sub identifier - not in allowedUsers", + userIdentifier: "xyz-not-allowed-user", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, + allowed: false, + }, + { + name: "Sub identifier - allowedDomains ignored for non-email", + userIdentifier: "user-id-12345", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored + allowedUsers: map[string]struct{}{"user-id-12345": {}}, + allowed: true, + }, + { + name: "Sub identifier - no restrictions allows all", + userIdentifier: "any-user-id", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{}, + allowed: true, + }, + { + name: "Sub identifier - case insensitive matching", + userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase + allowed: true, + }, + + // OID claim (Azure AD object ID) + { + name: "OID identifier - allowed user", + userIdentifier: "oid-12345-67890", + userIdentifierClaim: "oid", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"oid-12345-67890": {}}, + allowed: true, + }, + + // UPN claim (Azure AD User Principal Name) + { + name: "UPN identifier - allowed user (looks like email but use sub logic)", + userIdentifier: "user@tenant.onmicrosoft.com", + userIdentifierClaim: "upn", + allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored + allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}}, + allowed: true, + }, + + // Edge cases + { + name: "Empty identifier - not allowed", + userIdentifier: "", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"some-user": {}}, + allowed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Configure TraefikOidc instance for this test case + tOidc := ts.tOidc + tOidc.allowedUserDomains = tc.allowedDomains + tOidc.allowedUsers = tc.allowedUsers + tOidc.userIdentifierClaim = tc.userIdentifierClaim + + allowed := tOidc.isAllowedUser(tc.userIdentifier) + if allowed != tc.allowed { + t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q", + tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim) + } + }) + } +} + +func TestUserIdentifierClaimExtraction(t *testing.T) { + // Test that the correct claim is extracted based on userIdentifierClaim config + tests := []struct { + name string + userIdentifierClaim string + claims map[string]interface{} + expectedIdentifier string + shouldFallbackToSub bool + }{ + { + name: "Email claim extraction (default)", + userIdentifierClaim: "email", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + }, + expectedIdentifier: "user@example.com", + shouldFallbackToSub: false, + }, + { + name: "Sub claim extraction", + userIdentifierClaim: "sub", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + }, + expectedIdentifier: "user-sub-id", + shouldFallbackToSub: false, + }, + { + name: "OID claim extraction (Azure AD)", + userIdentifierClaim: "oid", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + "oid": "azure-object-id", + }, + expectedIdentifier: "azure-object-id", + shouldFallbackToSub: false, + }, + { + name: "UPN claim extraction (Azure AD)", + userIdentifierClaim: "upn", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "upn": "user@tenant.onmicrosoft.com", + }, + expectedIdentifier: "user@tenant.onmicrosoft.com", + shouldFallbackToSub: false, + }, + { + name: "Fallback to sub when configured claim is missing", + userIdentifierClaim: "email", + claims: map[string]interface{}{ + "sub": "fallback-sub-id", + // email is missing + }, + expectedIdentifier: "fallback-sub-id", + shouldFallbackToSub: true, + }, + { + name: "preferred_username claim extraction", + userIdentifierClaim: "preferred_username", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "preferred_username": "jdoe", + }, + expectedIdentifier: "jdoe", + shouldFallbackToSub: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Extract user identifier using the same logic as auth_flow.go + userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string) + usedFallback := false + + if userIdentifier == "" && tc.userIdentifierClaim != "sub" { + userIdentifier, _ = tc.claims["sub"].(string) + usedFallback = true + } + + if userIdentifier != tc.expectedIdentifier { + t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier) + } + + if usedFallback != tc.shouldFallbackToSub { + t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback) + } + }) + } +} + func TestOIDCHandler(t *testing.T) { ts := NewTestSuite(t) ts.Setup() diff --git a/metadata_cache.go b/metadata_cache.go index 040eccc..8f278da 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -189,11 +189,56 @@ func (mc *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client return mc.GetProviderMetadata(ctx, providerURL, httpClient) } -// GetMetadataWithRecovery fetches metadata with recovery support +// GetMetadataWithRecovery fetches metadata with retry support for startup scenarios. +// This handles the race condition where Traefik initializes the plugin before the +// OIDC provider routes are fully established, or before TLS certificates are loaded. +// Uses aggressive retry settings (10 attempts, 1s intervals) to give the infrastructure +// time to stabilize during cold starts. +// See: https://github.com/lukaszraczylo/traefikoidc/issues/90 func (mc *MetadataCache) GetMetadataWithRecovery(providerURL string, httpClient *http.Client, logger *Logger, errorRecoveryManager *ErrorRecoveryManager) (*ProviderMetadata, error) { - // For now, just use regular GetMetadata - // Recovery would be handled by ErrorRecoveryManager if needed - return mc.GetMetadata(providerURL, httpClient, logger) + // Check cache first - if we have valid cached metadata, use it + if metadata, exists := mc.Get(providerURL); exists { + return metadata, nil + } + + // Create a retry executor with metadata-fetch-specific configuration + retryConfig := MetadataFetchRetryConfig() + retryExecutor := NewRetryExecutor(retryConfig, logger) + + var metadata *ProviderMetadata + var lastErr error + + // Use context with overall timeout for the entire retry sequence + // 10 attempts * ~10s max delay = ~100s worst case, so use 2 minute timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + err := retryExecutor.ExecuteWithContext(ctx, func() error { + // Create per-attempt context with shorter timeout + attemptCtx, attemptCancel := context.WithTimeout(ctx, 15*time.Second) + defer attemptCancel() + + var fetchErr error + metadata, fetchErr = mc.GetProviderMetadata(attemptCtx, providerURL, httpClient) + if fetchErr != nil { + lastErr = fetchErr + if logger != nil { + logger.Debugf("Metadata fetch attempt failed: %v", fetchErr) + } + return fetchErr + } + return nil + }) + + if err != nil { + // Return the last actual error, not the retry wrapper error + if lastErr != nil { + return nil, lastErr + } + return nil, err + } + + return metadata, nil } // GetStats returns cache statistics for testing diff --git a/middleware.go b/middleware.go index 9d103e4..b8c5d6d 100644 --- a/middleware.go +++ b/middleware.go @@ -125,12 +125,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - email := session.GetEmail() - // Domain restriction check removed debug output - if authenticated && email != "" { - if !t.isAllowedDomain(email) { - t.logger.Infof("User with email %s is not from an allowed domain", email) - errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim) + // User authorization check + if authenticated && userIdentifier != "" { + if !t.isAllowedUser(userIdentifier) { + t.logger.Infof("User %s is not authorized", userIdentifier) + errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath) t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } @@ -193,10 +193,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { refreshed := t.refreshToken(rw, req, session) if refreshed { - email = session.GetEmail() - if email != "" && !t.isAllowedDomain(email) { - t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email) - errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier + if userIdentifier != "" && !t.isAllowedUser(userIdentifier) { + t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier) + errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath) t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } @@ -308,10 +308,13 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http req.Header.Set("X-Forwarded-User", email) - req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) - req.Header.Set("X-Auth-Request-User", email) - if idToken := session.GetIDToken(); idToken != "" { - req.Header.Set("X-Auth-Request-Token", idToken) + // When minimalHeaders is enabled, skip extra headers to prevent 431 errors + if !t.minimalHeaders { + req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) + req.Header.Set("X-Auth-Request-User", email) + if idToken := session.GetIDToken(); idToken != "" { + req.Header.Set("X-Auth-Request-Token", idToken) + } } if len(t.headerTemplates) > 0 { diff --git a/middleware/auth_middleware.go b/middleware/auth_middleware.go index ca85f08..0ddb4d7 100644 --- a/middleware/auth_middleware.go +++ b/middleware/auth_middleware.go @@ -41,6 +41,7 @@ type AuthMiddleware struct { goroutineWG *sync.WaitGroup startTokenCleanupFunc func() startMetadataRefreshFunc func(string) + minimalHeaders bool } // Logger interface for dependency injection @@ -120,6 +121,7 @@ func NewAuthMiddleware( goroutineWG *sync.WaitGroup, startTokenCleanupFunc func(), startMetadataRefreshFunc func(string), + minimalHeaders bool, ) *AuthMiddleware { return &AuthMiddleware{ logger: logger, @@ -149,6 +151,7 @@ func NewAuthMiddleware( goroutineWG: goroutineWG, startTokenCleanupFunc: startTokenCleanupFunc, startMetadataRefreshFunc: startMetadataRefreshFunc, + minimalHeaders: minimalHeaders, } } @@ -414,10 +417,14 @@ func (m *AuthMiddleware) processAuthorizedRequest(rw http.ResponseWriter, req *h } req.Header.Set("X-Forwarded-User", email) - req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) - req.Header.Set("X-Auth-Request-User", email) - if idToken := session.GetIDToken(); idToken != "" { - req.Header.Set("X-Auth-Request-Token", idToken) + + // When minimalHeaders is enabled, skip extra headers to prevent 431 errors + if !m.minimalHeaders { + req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) + req.Header.Set("X-Auth-Request-User", email) + if idToken := session.GetIDToken(); idToken != "" { + req.Header.Set("X-Auth-Request-Token", idToken) + } } m.next.ServeHTTP(rw, req) diff --git a/middleware/middleware_comprehensive_test.go b/middleware/middleware_comprehensive_test.go index d39b2dc..5497ab8 100644 --- a/middleware/middleware_comprehensive_test.go +++ b/middleware/middleware_comprehensive_test.go @@ -66,6 +66,7 @@ func TestNewAuthMiddleware(t *testing.T) { wg, startTokenCleanup, startMetadataRefresh, + false, // minimalHeaders ) if m == nil { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index a586bba..048852b 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -802,3 +802,99 @@ func TestServeHTTP_AdditionalCoverage(t *testing.T) { } }) } + +// TestProcessAuthorizedRequest_MinimalHeaders tests the minimalHeaders configuration +// This addresses GitHub issue #64 - Request Header Fields Too Large +func TestProcessAuthorizedRequest_MinimalHeaders(t *testing.T) { + tests := []struct { + name string + minimalHeaders bool + expectForwardedUser bool + expectAuthRequestUser bool + expectAuthRequestToken bool + expectAuthRequestRedirect bool + }{ + { + name: "minimalHeaders=false forwards all headers", + minimalHeaders: false, + expectForwardedUser: true, + expectAuthRequestUser: true, + expectAuthRequestToken: true, + expectAuthRequestRedirect: true, + }, + { + name: "minimalHeaders=true only forwards X-Forwarded-User", + minimalHeaders: true, + expectForwardedUser: true, + expectAuthRequestUser: false, + expectAuthRequestToken: false, + expectAuthRequestRedirect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := &mockLogger{} + var capturedHeaders http.Header + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + }) + + session := &mockSessionData{ + email: "user@example.com", + idToken: "test-id-token-that-could-be-very-large", + accessToken: "test-access-token", + } + + m := &AuthMiddleware{ + logger: logger, + next: nextHandler, + minimalHeaders: tt.minimalHeaders, + extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) { + return nil, nil, nil + }, + } + + req := httptest.NewRequest("GET", "/protected", nil) + rw := httptest.NewRecorder() + + m.processAuthorizedRequest(rw, req, session, "https://example.com/callback") + + // Verify X-Forwarded-User is always set + if tt.expectForwardedUser { + if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" { + t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User")) + } + } + + // Verify X-Auth-Request-User + hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != "" + if tt.expectAuthRequestUser && !hasAuthRequestUser { + t.Error("expected X-Auth-Request-User to be set") + } + if !tt.expectAuthRequestUser && hasAuthRequestUser { + t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User")) + } + + // Verify X-Auth-Request-Token (the big one that causes 431 errors) + hasAuthRequestToken := capturedHeaders.Get("X-Auth-Request-Token") != "" + if tt.expectAuthRequestToken && !hasAuthRequestToken { + t.Error("expected X-Auth-Request-Token to be set") + } + if !tt.expectAuthRequestToken && hasAuthRequestToken { + t.Errorf("expected X-Auth-Request-Token to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Token")) + } + + // Verify X-Auth-Request-Redirect + hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != "" + if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect { + t.Error("expected X-Auth-Request-Redirect to be set") + } + if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect { + t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect")) + } + }) + } +} diff --git a/settings.go b/settings.go index c22803d..42a043d 100644 --- a/settings.go +++ b/settings.go @@ -127,10 +127,63 @@ type Config struct { // Default: "groups" GroupClaimName string `json:"groupClaimName,omitempty"` + // UserIdentifierClaim specifies the JWT claim to use as the user identifier. + // This allows authentication for users without email addresses (e.g., Azure AD service accounts). + // + // Examples: + // - Default (backward compatible): "email" + // - Azure AD without email: "sub", "oid", "upn", or "preferred_username" + // - Generic OIDC: "sub" (always present per OIDC spec) + // + // When set to a non-email claim: + // - AllowedUsers will match against this claim value instead of email + // - AllowedUserDomains validation is skipped (domains only apply to email) + // - The session will store this identifier as the user's identity + // + // Default: "email" + UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"` + // DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591) // When enabled, the middleware will automatically register as a client with // the OIDC provider if ClientID/ClientSecret are not provided. DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"` + + // AllowPrivateIPAddresses disables the security check that blocks private/internal IP addresses. + // By default, the plugin rejects URLs containing private IP ranges (10.x.x.x, 172.16-31.x.x, 192.168.x.x) + // to prevent SSRF attacks and ensure OIDC providers are publicly accessible. + // + // Enable this option ONLY when: + // - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs + // - You have no DNS resolution available for internal services + // - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing + // + // Security Warning: Enabling this option reduces SSRF protection. Only use in trusted + // network environments where the OIDC provider is known and controlled. + // + // Default: false (private IPs are blocked for security) + AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"` + + // MinimalHeaders reduces the number of headers forwarded to downstream services. + // This helps prevent "431 Request Header Fields Too Large" errors when downstream + // services have limited header buffer sizes. + // + // When enabled (true): + // - Only forwards: X-Forwarded-User + // - Skips: X-Auth-Request-Token (full ID token), X-Auth-Request-Redirect + // - Groups/roles headers (X-User-Groups, X-User-Roles) are still forwarded if configured + // - Custom templated headers are still processed + // + // When disabled (false, default): + // - Forwards all headers: X-Forwarded-User, X-Auth-Request-User, X-Auth-Request-Redirect, + // X-Auth-Request-Token (full ID token) + // + // Use this option when: + // - Downstream services return "431 Request Header Fields Too Large" errors + // - You don't need the full ID token forwarded to backend services + // - You want to reduce request overhead + // + // Default: false (all headers forwarded for backward compatibility) + MinimalHeaders bool `json:"minimalHeaders,omitempty"` } // RedisConfig configures Redis cache backend settings for distributed caching. diff --git a/types.go b/types.go index 0307661..4c377fd 100644 --- a/types.go +++ b/types.go @@ -99,6 +99,7 @@ type TraefikOidc struct { audience string // Expected JWT audience, defaults to clientID roleClaimName string // JWT claim name for extracting roles, defaults to "roles" groupClaimName string // JWT claim name for extracting groups, defaults to "groups" + userIdentifierClaim string // JWT claim for user identification, defaults to "email" name string redirURLPath string logoutURLPath string @@ -128,6 +129,8 @@ type TraefikOidc struct { suppressDiagnosticLogs bool firstRequestReceived bool metadataRefreshStarted bool + allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) + minimalHeaders bool // Reduce headers to prevent 431 errors securityHeadersApplier func(http.ResponseWriter, *http.Request) scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering scopesSupported []string // NEW - from provider metadata diff --git a/url_helpers.go b/url_helpers.go index df19d34..f12e731 100644 --- a/url_helpers.go +++ b/url_helpers.go @@ -340,6 +340,7 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error { // validateHost validates a hostname or IP address for security. // It prevents access to localhost, private networks, and known metadata endpoints. +// When allowPrivateIPAddresses is enabled, private IP checks are skipped. // Parameters: // - host: The host string to validate (may include port). // @@ -357,7 +358,13 @@ func (t *TraefikOidc) validateHost(host string) error { ip := net.ParseIP(hostname) if ip != nil { - if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + // Always block loopback, link-local, and multicast addresses + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return fmt.Errorf("access to loopback/link-local IP addresses is not allowed: %s", ip.String()) + } + + // Skip private IP check if allowPrivateIPAddresses is enabled + if !t.allowPrivateIPAddresses && ip.IsPrivate() { return fmt.Errorf("access to private/internal IP addresses is not allowed: %s", ip.String()) } diff --git a/utilities.go b/utilities.go index ddfbc57..dce4518 100644 --- a/utilities.go +++ b/utilities.go @@ -55,6 +55,51 @@ func (t *TraefikOidc) safeLogInfo(msg string) { // DOMAIN VALIDATION // ============================================================================= +// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim. +// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains. +// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers +// since domain-based validation doesn't apply to non-email identifiers. +// +// Parameters: +// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.). +// +// Returns: +// - true if the user is authorized, false otherwise. +func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool { + // If no restrictions are configured, allow all authenticated users + if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 { + return true + } + + // Check if user is explicitly allowed + if len(t.allowedUsers) > 0 { + _, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)] + if userAllowed { + t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier) + return true + } + } + + // For email-based identifiers, also check domain restrictions + // Only apply domain validation if using email as identifier AND identifier looks like an email + if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") { + return t.isAllowedDomain(userIdentifier) + } + + // For non-email identifiers with allowedUserDomains configured, log a warning + if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" { + t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s", + t.userIdentifierClaim, userIdentifier) + } + + // User not found in allowedUsers list + if len(t.allowedUsers) > 0 { + t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier) + } + + return false +} + // isAllowedDomain checks if an email address is authorized based on domain or user whitelist. // It validates against both allowed user domains and specific allowed users. // Parameters: