December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)

* Allow internal IPs for OIDC configuration via extra flag.

Addresses issue #97

* Allow for internal IPs in OIDC configuration.

Addresses issue #97.

* feat: Add allowPrivateIPAddresses config option for internal networks

Adds a new configuration option `allowPrivateIPAddresses` that allows
OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x,
192.168.x.x). This is useful for internal deployments where Keycloak or
other OIDC providers run on private networks without DNS resolution.

Security considerations:
- Loopback addresses (127.0.0.1, localhost, ::1) remain blocked
- Link-local addresses (169.254.x.x) remain blocked
- Default is false (secure by default)

Fixes #97

* feat: Support non-email user identifiers for Azure AD

Add userIdentifierClaim configuration option to support Azure AD users
without email addresses. This allows using alternative JWT claims like
"sub", "oid", "upn", or "preferred_username" for user identification.

- Default behavior uses "email" claim (backward compatible)
- Falls back to "sub" claim if configured claim is missing
- allowedUsers matches against the configured claim value
- allowedUserDomains only applies when using email-based identification

Fixes #95

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Race condition on traefik pod startup

When the plugin initializes and calls GetMetadataWithRecovery():

1. Checks cache first (if metadata is cached, returns immediately)
2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays)
3. Attempts to fetch metadata from the OIDC provider
4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries
5. After 10 attempts or on a non-retryable error, returns the error

This allows the plugin to handle the race condition where:
- Traefik initializes the plugin before routes are established
- Traefik serves its default certificate before loading real ones
- The OIDC provider pod isn't fully ready yet

Fixes issue #90

* Headers too big and 431 responses

Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services.

  - When minimalHeaders: false (default): All headers are forwarded as before
    - X-Forwarded-User (always set)
    - X-Auth-Request-Redirect
    - X-Auth-Request-User
    - X-Auth-Request-Token (the large ID token)
    - X-User-Groups, X-User-Roles (if configured)
  - When minimalHeaders: true: Reduces header overhead
    - X-Forwarded-User (always set)
    - X-User-Groups, X-User-Roles (still forwarded if configured)
    - Custom templated headers (still processed)
    - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect

Fixes issues #64 and #86
This commit is contained in:
2025-12-08 14:21:17 +00:00
committed by GitHub
parent a750c4f5b9
commit 9126c74723
25 changed files with 1642 additions and 212 deletions
+118
View File
@@ -77,6 +77,7 @@ testData:
# Custom claim names for Auth0 and other providers with namespaced claims
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
# When NOT specified in config: defaults to FALSE (Go zero value)
@@ -120,6 +121,8 @@ testData:
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
# Security Headers Configuration (enabled by default with 'default' profile)
securityHeaders:
@@ -266,6 +269,8 @@ testDataWithRedis:
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
# - admin
# - editor
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
# # See README.md "Provider Configuration Recommendations" for Keycloak.
@@ -287,6 +292,26 @@ testDataWithRedis:
# - "AppRoleName"
# # See README.md "Provider Configuration Recommendations" for Azure AD.
# --- Azure AD Users Without Email Example (Issue #95) ---
# testDataAzureADNoEmail:
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
# clientID: your-azure-ad-client-id
# clientSecret: your-azure-ad-client-secret
# callbackURL: /oauth2/callback
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
# # Use 'sub' claim instead of 'email' for user identification
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
# overrideScopes: true # Remove email scope if not needed
# scopes:
# - openid
# - profile
# - groups # For group-based access control
# # When using non-email identifiers, allowedUsers matches against the claim value
# allowedUsers:
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
# --- Google Workspace / Google Cloud Identity Example ---
# testDataGoogle:
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
@@ -605,6 +630,38 @@ configuration:
items:
type: string
userIdentifierClaim:
type: string
description: |
Specifies the JWT claim to use as the user identifier for authentication and authorization.
This allows authentication for users without email addresses, such as Azure AD service
accounts or organizational accounts that don't have email attributes configured.
When set to a non-email claim (e.g., "sub", "oid", "upn"):
- AllowedUsers will match against this claim value instead of email
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
- The session stores this identifier as the user's identity
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
Common values by provider:
- Default: "email" (standard email-based identification)
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
- Generic OIDC: "sub" (always present per OIDC specification)
- Keycloak: "sub", "preferred_username"
Example for Azure AD users without email:
```yaml
userIdentifierClaim: sub
allowedUsers:
- "abc123-user-object-id"
- "xyz789-another-user-id"
```
Default: "email"
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
required: false
revocationURL:
type: string
description: |
@@ -903,6 +960,67 @@ configuration:
Default: false (replay detection enabled)
required: false
allowPrivateIPAddresses:
type: boolean
description: |
Allow private IP addresses in OIDC provider URLs for internal network deployments.
By default, the plugin blocks URLs containing private IP address ranges
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
OIDC providers are publicly accessible.
Enable this option when:
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
- You don't have DNS resolution available for internal services
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
When enabled, the plugin will accept provider URLs like:
- https://192.168.1.100:8443/auth/realms/your-realm
- https://10.0.0.50:8080/realms/master
- https://172.16.0.10/auth
Security Warning:
Enabling this option reduces SSRF protection. Only use in trusted network
environments where the OIDC provider is known and controlled. Loopback
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
Default: false (private IPs are blocked for security)
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
required: false
minimalHeaders:
type: boolean
description: |
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
When enabled, the middleware only forwards the X-Forwarded-User header and skips
the larger authentication headers that can cause downstream services to reject
requests due to header size limits (typically 8KB).
Headers when disabled (default):
- X-Forwarded-User: User's email address (always set)
- X-Auth-Request-Redirect: Original request URI
- X-Auth-Request-User: User's email address
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
- X-User-Groups: Comma-separated user groups (if configured)
- X-User-Roles: Comma-separated user roles (if configured)
Headers when enabled:
- X-Forwarded-User: User's email address (always set)
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
- Custom templated headers (still processed)
Use this option when:
- Downstream services return "431 Request Header Fields Too Large" errors
- Your ID tokens are large (many claims, long group lists)
- You don't need the full ID token forwarded to backend services
- You want to reduce request overhead
Default: false (all headers forwarded for backward compatibility)
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
required: false
headers:
type: array
description: |
+84 -2
View File
@@ -124,6 +124,7 @@ The middleware supports the following configuration options:
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
@@ -138,6 +139,8 @@ The middleware supports the following configuration options:
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
@@ -1241,6 +1244,45 @@ spec:
- "AppRoleName" # Application role names
```
### Azure AD Configuration (Users Without Email)
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-azure-no-email
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
clientID: your-azure-ad-client-id
clientSecret: your-azure-ad-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
# Use 'sub' instead of 'email' for user identification
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
overrideScopes: true # Optional: Don't request email scope if not needed
scopes:
- openid
- profile
- groups
# When using non-email identifiers, allowedUsers matches against the claim value
allowedUsers:
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
- "def67890-1234-5678-90ab-cdef12345678"
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
```
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
### Auth0 Configuration
```yaml
@@ -1327,8 +1369,12 @@ spec:
- admin
- editor
# Ensure Keycloak client mappers add necessary claims to ID Token
# For internal Keycloak deployments with private IPs (e.g., Docker network):
# allowPrivateIPAddresses: true
```
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
### AWS Cognito Configuration
```yaml
@@ -1629,12 +1675,39 @@ headers:
When a user is authenticated, the middleware sets the following headers for downstream services:
- `X-Forwarded-User`: The user's email address
- `X-Forwarded-User`: The user's email address (always set)
- `X-User-Groups`: Comma-separated list of user groups (if available)
- `X-User-Roles`: Comma-separated list of user roles (if available)
- `X-Auth-Request-Redirect`: The original request URI
- `X-Auth-Request-User`: The user's email address
- `X-Auth-Request-Token`: The user's access token
- `X-Auth-Request-Token`: The user's ID token (can be large)
#### Minimal Headers Mode
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
```yaml
http:
middlewares:
my-auth:
plugin:
traefikoidc:
minimalHeaders: true
# ... other config
```
When `minimalHeaders: true` is set:
- **Only forwards**: `X-Forwarded-User`
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
- **Still processes**: Custom templated headers
This is particularly useful when:
- Your ID tokens are large (many claims, long group lists)
- Downstream services have limited header buffer sizes (default 8KB in many servers)
- You don't need the full token forwarded to backend services
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
### Security Headers
@@ -1862,6 +1935,15 @@ logLevel: debug
- No refresh tokens (re-authentication required on expiry)
- Use only for GitHub API access, not user authentication
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
- Any name that doesn't contain the literal substring "API"
### Provider Warnings and Recommendations
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
+21 -20
View File
@@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) {
customAudience := "https://api.company.com"
tOidc := &TraefikOidc{
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://auth.company.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
audience: customAudience, // Set custom audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.company.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
allowedUserDomains: map[string]struct{}{"company.com": {}},
excludedURLs: map[string]struct{}{},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
next: nextHandler,
name: "test",
redirURLPath: "/callback",
logoutURLPath: "/callback/logout",
issuerURL: "https://auth.company.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
audience: customAudience, // Set custom audience
jwkCache: mockJWKCache,
jwksURL: "https://auth.company.com/.well-known/jwks.json",
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
logger: logger,
allowedUserDomains: map[string]struct{}{"company.com": {}},
userIdentifierClaim: "email", // Required for user identification
excludedURLs: map[string]struct{}{},
httpClient: &http.Client{},
initComplete: make(chan struct{}),
sessionManager: sm,
extractClaimsFunc: extractClaims,
}
tOidc.jwtVerifier = tOidc
tOidc.tokenVerifier = tOidc
+29 -25
View File
@@ -18,17 +18,18 @@ type ScopeFilter interface {
// Handler provides core authentication functionality for OIDC flows
type Handler struct {
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
scopeFilter ScopeFilter // NEW
scopesSupported []string // NEW - from provider metadata
logger Logger
enablePKCE bool
isGoogleProv func() bool
isAzureProv func() bool
clientID string
authURL string
issuerURL string
scopes []string
overrideScopes bool
scopeFilter ScopeFilter // NEW
scopesSupported []string // NEW - from provider metadata
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
}
// Logger interface for dependency injection
@@ -40,19 +41,20 @@ type Logger interface {
// NewAuthHandler creates a new Handler instance
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
scopeFilter ScopeFilter, scopesSupported []string, allowPrivateIPAddresses bool) *Handler {
return &Handler{
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
scopeFilter: scopeFilter, // NEW
scopesSupported: scopesSupported, // NEW
logger: logger,
enablePKCE: enablePKCE,
isGoogleProv: isGoogleProv,
isAzureProv: isAzureProv,
clientID: clientID,
authURL: authURL,
issuerURL: issuerURL,
scopes: scopes,
overrideScopes: overrideScopes,
scopeFilter: scopeFilter,
scopesSupported: scopesSupported,
allowPrivateIPAddresses: allowPrivateIPAddresses,
}
}
@@ -347,6 +349,7 @@ func (h *Handler) validateParsedURL(u *url.URL) error {
// validateHost validates a hostname for security and reachability.
// It prevents access to private networks and localhost addresses.
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
func (h *Handler) validateHost(host string) error {
if host == "" {
return fmt.Errorf("empty host")
@@ -361,7 +364,7 @@ func (h *Handler) validateHost(host string) error {
}
}
// Check for localhost variations
// Check for localhost variations (always blocked, even with allowPrivateIPAddresses)
localhostVariations := []string{
"localhost", "127.0.0.1", "::1", "0.0.0.0",
}
@@ -376,7 +379,8 @@ func (h *Handler) validateHost(host string) error {
if ip.IsLoopback() {
return fmt.Errorf("loopback IP not allowed: %s", host)
}
if ip.IsPrivate() {
// Skip private IP check if allowPrivateIPAddresses is enabled
if !h.allowPrivateIPAddresses && ip.IsPrivate() {
return fmt.Errorf("private IP not allowed: %s", host)
}
if ip.IsLinkLocalUnicast() {
+25 -25
View File
@@ -86,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
"test-client-id", "https://example.com/auth", "https://example.com",
scopes, false, nil, nil)
scopes, false, nil, nil, false)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
@@ -125,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
session := &mockSessionData{redirectCount: 5} // At the limit
req := httptest.NewRequest("GET", "/test", nil)
@@ -160,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -191,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -222,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/test", nil)
@@ -253,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
session := &mockSessionData{saveError: &testError{"save failed"}}
req := httptest.NewRequest("GET", "/test?param=value", nil)
@@ -297,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil, false)
session := &mockSessionData{}
req := httptest.NewRequest("GET", "/protected/resource", nil)
@@ -400,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
[]string{"openid", "profile", "email"}, false, nil, nil)
[]string{"openid", "profile", "email"}, false, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -440,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/tenant/v2.0",
[]string{"openid", "profile", "email"}, false, nil, nil)
[]string{"openid", "profile", "email"}, false, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -468,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false, nil, nil)
[]string{"openid"}, false, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
@@ -493,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"no-pkce-client", "https://example.com/auth", "https://example.com",
[]string{"openid"}, false, nil, nil)
[]string{"openid"}, false, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
@@ -565,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
"test-client", "https://example.com/auth", "https://example.com",
tt.scopes, tt.overrideScopes, nil, nil)
tt.scopes, tt.overrideScopes, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -634,7 +634,7 @@ func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -676,7 +676,7 @@ func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, nil, nil)
scopes, false, nil, nil, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -714,7 +714,7 @@ func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
"https://gitlab.example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -756,7 +756,7 @@ func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
"https://accounts.google.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -797,7 +797,7 @@ func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/tenant/v2.0",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -831,7 +831,7 @@ func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"generic-client", "https://auth.provider.com/authorize",
"https://auth.provider.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -870,7 +870,7 @@ func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, true, scopeFilter, scopesSupported)
scopes, true, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -916,7 +916,7 @@ func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -955,7 +955,7 @@ func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, nil, scopesSupported) // scopeFilter is nil
scopes, false, nil, scopesSupported, false) // scopeFilter is nil
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -988,7 +988,7 @@ func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
@@ -1021,7 +1021,7 @@ func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
@@ -1064,7 +1064,7 @@ func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
@@ -1130,7 +1130,7 @@ func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com",
scopes, false, scopeFilter, scopesSupported)
scopes, false, scopeFilter, scopesSupported, false)
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
+103 -5
View File
@@ -10,7 +10,7 @@ import (
func TestAuthHandler_validateURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
tests := []struct {
name string
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
func TestAuthHandler_validateHost(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
tests := []struct {
name string
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
func TestAuthHandler_buildURLWithParams(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
tests := []struct {
name string
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
// Test special characters that need encoding
params := url.Values{
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
func TestAuthHandler_validateParsedURL(t *testing.T) {
logger := &mockLogger{}
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
tests := []struct {
name string
@@ -560,3 +560,101 @@ func TestAuthHandler_validateParsedURL(t *testing.T) {
})
}
}
// TestAuthHandler_validateHost_AllowPrivateIPAddresses tests the allowPrivateIPAddresses flag
func TestAuthHandler_validateHost_AllowPrivateIPAddresses(t *testing.T) {
logger := &mockLogger{}
// Test with allowPrivateIPAddresses = false (default)
t.Run("Private IPs blocked by default", func(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
privateIPs := []string{
"192.168.1.1",
"10.0.0.1",
"172.16.0.1",
"172.31.255.255",
}
for _, ip := range privateIPs {
err := handler.validateHost(ip)
if err == nil {
t.Errorf("Expected private IP %s to be blocked, but it was allowed", ip)
}
if err != nil && !strings.Contains(err.Error(), "private IP not allowed") {
t.Errorf("Expected 'private IP not allowed' error for %s, got: %v", ip, err)
}
}
})
// Test with allowPrivateIPAddresses = true
t.Run("Private IPs allowed when flag enabled", func(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
privateIPs := []string{
"192.168.1.1",
"10.0.0.1",
"172.16.0.1",
"172.31.255.255",
}
for _, ip := range privateIPs {
err := handler.validateHost(ip)
if err != nil {
t.Errorf("Expected private IP %s to be allowed with flag enabled, but got error: %v", ip, err)
}
}
})
// Test that loopback is still blocked even with flag enabled
t.Run("Loopback always blocked", func(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
loopbackAddresses := []string{
"127.0.0.1",
"localhost",
"::1",
"0.0.0.0",
}
for _, addr := range loopbackAddresses {
err := handler.validateHost(addr)
if err == nil {
t.Errorf("Expected loopback address %s to be blocked even with allowPrivateIPAddresses=true", addr)
}
}
})
// Test that link-local is still blocked even with flag enabled
t.Run("Link-local always blocked", func(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
err := handler.validateHost("169.254.1.1")
if err == nil {
t.Error("Expected link-local address to be blocked even with allowPrivateIPAddresses=true")
}
})
// Test that public IPs work with flag enabled
t.Run("Public IPs allowed", func(t *testing.T) {
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
publicIPs := []string{
"8.8.8.8",
"1.1.1.1",
"142.250.185.68",
}
for _, ip := range publicIPs {
err := handler.validateHost(ip)
if err != nil {
t.Errorf("Expected public IP %s to be allowed, but got error: %v", ip, err)
}
}
})
}
+19 -9
View File
@@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("Email claim missing or empty in token during callback")
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
}
if !t.isAllowedDomain(email) {
t.logger.Errorf("Disallowed email domain during callback: %s", email)
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
// Validate user authorization
if !t.isAllowedUser(userIdentifier) {
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
@@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
+15
View File
@@ -437,6 +437,21 @@ http:
4. Configure client scopes and mappers
5. Generate client secret in Credentials tab
### Internal Network Deployment
If your Keycloak instance runs on an internal network with private IP addresses (e.g., Docker networks, Kubernetes internal services), set `allowPrivateIPAddresses: true`:
```yaml
traefikoidc:
providerUrl: "https://192.168.1.100:8443/auth/realms/your-realm" # Private IP
allowPrivateIPAddresses: true # Required for private IP addresses
clientId: "your-client-id"
clientSecret: "your-client-secret"
# ... other config
```
> **Security Warning**: Only enable `allowPrivateIPAddresses` in trusted network environments where you control the OIDC provider. This setting reduces SSRF protection.
---
## Okta
+130
View File
@@ -2,10 +2,14 @@ package traefikoidc
import (
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"math"
"math/rand/v2"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
}
}
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
// fetching during startup. Uses more aggressive retry settings to handle the race
// condition where Traefik initializes the plugin before routes are fully established,
// or before TLS certificates are properly loaded.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func MetadataFetchRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 10, // More attempts for startup scenarios
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
MaxDelay: 10 * time.Second, // Cap at 10 seconds
BackoffFactor: 1.5, // Gentler backoff for startup
EnableJitter: true, // Prevent thundering herd
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
}
// RetryExecutor implements retry logic with exponential backoff and jitter.
// It automatically retries failed operations based on configurable error patterns
// and uses exponential backoff to avoid overwhelming failing services.
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
// isRetryableError checks if an error should trigger a retry
// isRetryableError determines if an error should trigger a retry attempt.
// Checks error message against configured retryable error patterns.
// Also handles startup-specific errors like Traefik default certificate errors
// and EOF errors that occur during service initialization.
func (re *RetryExecutor) isRetryableError(err error) bool {
if err == nil {
return false
}
// Check for Traefik default certificate error (startup race condition)
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
if isTraefikDefaultCertError(err) {
return true
}
// Check for EOF errors (common during startup when services aren't ready)
if isEOFError(err) {
return true
}
// Check for certificate errors (transient during startup)
if isCertificateError(err) {
return true
}
errStr := err.Error()
for _, retryableErr := range re.config.RetryableErrors {
@@ -1088,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
}
return false
}
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
// certificate during cold-start, before the real certificates are loaded.
// This manifests as an x509.HostnameError where one of the certificate's DNS names
// ends with "traefik.default" (the default Traefik certificate pattern).
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func isTraefikDefaultCertError(err error) bool {
if err == nil {
return false
}
var hostnameErr x509.HostnameError
if errors.As(err, &hostnameErr) {
if hostnameErr.Certificate != nil {
for _, name := range hostnameErr.Certificate.DNSNames {
if strings.HasSuffix(name, "traefik.default") {
return true
}
}
}
}
return false
}
// isEOFError checks if an error is an EOF error, which can occur during
// connection establishment when the remote end closes unexpectedly.
// This is common during service startup when endpoints aren't fully ready.
func isEOFError(err error) bool {
if err == nil {
return false
}
// Check for direct EOF
if errors.Is(err, io.EOF) {
return true
}
// Check for unexpected EOF
if errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check error message for EOF patterns (wrapped errors)
errStr := err.Error()
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
}
// isCertificateError checks if an error is related to TLS certificate validation.
// These errors are often transient during startup when services are still initializing.
func isCertificateError(err error) bool {
if err == nil {
return false
}
// Check for x509 certificate errors
var certInvalidErr x509.CertificateInvalidError
var hostnameErr x509.HostnameError
var unknownAuthErr x509.UnknownAuthorityError
if errors.As(err, &certInvalidErr) ||
errors.As(err, &hostnameErr) ||
errors.As(err, &unknownAuthErr) {
return true
}
// Check error message for certificate patterns
errStr := strings.ToLower(err.Error())
certPatterns := []string{
"certificate",
"x509",
"tls",
"ssl",
}
for _, pattern := range certPatterns {
if strings.Contains(errStr, pattern) {
return true
}
}
return false
}
+293
View File
@@ -846,3 +846,296 @@ func (e *mockNetError) Temporary() bool { return e.temporary }
// Ensure mockNetError implements net.Error
var _ net.Error = (*mockNetError)(nil)
// Test isTraefikDefaultCertError
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func TestIsTraefikDefaultCertError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "regular error",
err: errors.New("some error"),
expected: false,
},
{
name: "network error",
err: &mockNetError{msg: "connection refused"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isTraefikDefaultCertError(tt.err)
if result != tt.expected {
t.Errorf("isTraefikDefaultCertError() = %v, expected %v", result, tt.expected)
}
})
}
}
// Test isEOFError
func TestIsEOFError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "regular error",
err: errors.New("some error"),
expected: false,
},
{
name: "error containing EOF in message",
err: errors.New("connection closed: EOF"),
expected: true,
},
{
name: "error containing unexpected EOF",
err: errors.New("read: unexpected EOF"),
expected: true,
},
{
name: "network error without EOF",
err: &mockNetError{msg: "connection refused"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isEOFError(tt.err)
if result != tt.expected {
t.Errorf("isEOFError() = %v, expected %v", result, tt.expected)
}
})
}
}
// Test isCertificateError
func TestIsCertificateError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "regular error",
err: errors.New("some error"),
expected: false,
},
{
name: "error containing certificate in message",
err: errors.New("tls: failed to verify certificate"),
expected: true,
},
{
name: "error containing x509 in message",
err: errors.New("x509: certificate signed by unknown authority"),
expected: true,
},
{
name: "error containing tls in message",
err: errors.New("tls handshake failed"),
expected: true,
},
{
name: "error containing ssl in message",
err: errors.New("ssl connection error"),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isCertificateError(tt.err)
if result != tt.expected {
t.Errorf("isCertificateError() = %v, expected %v", result, tt.expected)
}
})
}
}
// Test MetadataFetchRetryConfig
func TestMetadataFetchRetryConfig(t *testing.T) {
config := MetadataFetchRetryConfig()
if config.MaxAttempts != 10 {
t.Errorf("Expected MaxAttempts 10, got %d", config.MaxAttempts)
}
if config.InitialDelay != 1*time.Second {
t.Errorf("Expected InitialDelay 1s, got %v", config.InitialDelay)
}
if config.MaxDelay != 10*time.Second {
t.Errorf("Expected MaxDelay 10s, got %v", config.MaxDelay)
}
if config.BackoffFactor != 1.5 {
t.Errorf("Expected BackoffFactor 1.5, got %v", config.BackoffFactor)
}
if !config.EnableJitter {
t.Error("Expected EnableJitter to be true")
}
// Verify retryable errors include startup-related patterns
expectedPatterns := []string{"EOF", "certificate", "x509", "tls"}
for _, pattern := range expectedPatterns {
found := false
for _, retryableErr := range config.RetryableErrors {
if retryableErr == pattern {
found = true
break
}
}
if !found {
t.Errorf("Expected '%s' in RetryableErrors", pattern)
}
}
}
// Test RetryExecutor with startup-specific errors
func TestRetryExecutorStartupErrors(t *testing.T) {
// Verify MetadataFetchRetryConfig creates a valid retry executor
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
tests := []struct {
name string
err error
shouldRetry bool
}{
{
name: "EOF error",
err: errors.New("read tcp: EOF"),
shouldRetry: true,
},
{
name: "unexpected EOF",
err: errors.New("http: unexpected EOF"),
shouldRetry: true,
},
{
name: "certificate error",
err: errors.New("x509: certificate signed by unknown authority"),
shouldRetry: true,
},
{
name: "TLS error",
err: errors.New("tls: failed to verify certificate"),
shouldRetry: true,
},
{
name: "connection refused",
err: errors.New("dial tcp: connection refused"),
shouldRetry: true,
},
{
name: "permanent error",
err: errors.New("invalid response format"),
shouldRetry: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use very short delays for testing
testConfig := RetryConfig{
MaxAttempts: 3,
InitialDelay: 1 * time.Millisecond,
MaxDelay: 10 * time.Millisecond,
BackoffFactor: 1.5,
EnableJitter: false,
RetryableErrors: []string{
"connection refused",
"timeout",
"temporary failure",
"network unreachable",
"EOF",
"certificate",
"x509",
"tls",
},
}
testRe := NewRetryExecutor(testConfig, nil)
attempts := 0
_ = testRe.ExecuteWithContext(context.Background(), func() error {
attempts++
return tt.err
})
expectedAttempts := 1
if tt.shouldRetry {
expectedAttempts = 3
}
if attempts != expectedAttempts {
t.Errorf("Expected %d attempts for '%s', got %d", expectedAttempts, tt.name, attempts)
}
})
}
}
// Test that retry executor properly uses isRetryableError with new error types
func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
re := NewRetryExecutor(DefaultRetryConfig(), nil)
// Test that the enhanced isRetryableError is being used
tests := []struct {
name string
err error
shouldRetry bool
}{
{
name: "EOF in error message",
err: errors.New("connection reset by peer: EOF"),
shouldRetry: true,
},
{
name: "certificate in error message",
err: errors.New("x509: certificate has expired"),
shouldRetry: true,
},
{
name: "TLS in error message",
err: errors.New("tls: handshake failure"),
shouldRetry: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := re.isRetryableError(tt.err)
if result != tt.shouldRetry {
t.Errorf("isRetryableError(%q) = %v, expected %v", tt.err.Error(), result, tt.shouldRetry)
}
})
}
}
+29 -12
View File
@@ -15,7 +15,8 @@ type OAuthHandler struct {
tokenExchanger TokenExchanger
tokenVerifier TokenVerifier
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
isAllowedDomainFunc func(email string) bool
isAllowedUserFunc func(userIdentifier string) bool // validates user authorization
userIdentifierClaim string // JWT claim to use for user identification
redirURLPath string
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
}
@@ -77,16 +78,22 @@ type TokenResponse struct {
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
isAllowedDomainFunc func(string) bool, redirURLPath string,
isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string,
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
// Default to "email" for backward compatibility
if userIdentifierClaim == "" {
userIdentifierClaim = "email"
}
return &OAuthHandler{
logger: logger,
sessionManager: sessionManager,
tokenExchanger: tokenExchanger,
tokenVerifier: tokenVerifier,
extractClaimsFunc: extractClaimsFunc,
isAllowedDomainFunc: isAllowedDomainFunc,
isAllowedUserFunc: isAllowedUserFunc,
userIdentifierClaim: userIdentifierClaim,
redirURLPath: redirURLPath,
sendErrorResponseFunc: sendErrorResponseFunc,
}
@@ -225,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
return
}
email, _ := claims["email"].(string)
if email == "" {
h.logger.Errorf("Email claim missing or empty in token during callback")
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
return
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
userIdentifier, _ := claims[h.userIdentifierClaim].(string)
if userIdentifier == "" {
// Try "sub" as fallback since it's required by OIDC spec
if h.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim)
h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
return
}
h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim)
}
if !h.isAllowedDomainFunc(email) {
h.logger.Errorf("Disallowed email domain during callback: %s", email)
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
// Validate user authorization
if !h.isAllowedUserFunc(userIdentifier) {
h.logger.Errorf("User not authorized during callback: %s", userIdentifier)
h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
return
}
@@ -242,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(email)
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
+25 -25
View File
@@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
}
isAllowed := func(email string) bool { return true }
isAllowedUser := func(userIdentifier string) bool { return true }
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowedUser, "email", "/callback", sendError)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
@@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
rw := httptest.NewRecorder()
@@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
// Test with error parameter
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
@@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test", nil)
rw := httptest.NewRecorder()
@@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
rw := httptest.NewRecorder()
@@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
if code != http.StatusInternalServerError {
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
}
if !strings.Contains(msg, "Email missing in token") {
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
if !strings.Contains(msg, "User identifier missing in token") {
t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
if code != http.StatusForbidden {
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
}
if !strings.Contains(msg, "Email domain not allowed") {
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
if !strings.Contains(msg, "User not authorized") {
t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg)
}
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
@@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
}
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
extractClaims, isAllowed, "/callback", sendError)
extractClaims, isAllowed, "email", "/callback", sendError)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
+52 -47
View File
@@ -15,20 +15,21 @@ import (
// XSS, path traversal, and other injection attacks. It validates and sanitizes
// various input types used in OIDC authentication flows.
type InputValidator struct {
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
usernameRegex *regexp.Regexp
tokenRegex *regexp.Regexp
logger *Logger
urlRegex *regexp.Regexp
emailRegex *regexp.Regexp
sqlInjectionPatterns []string
pathTraversalPatterns []string
xssPatterns []string
maxUsernameLength int
maxURLLength int
maxTokenLength int
maxEmailLength int
maxClaimLength int
maxHeaderLength int
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
}
// ValidationResult encapsulates the outcome of input validation.
@@ -46,13 +47,14 @@ type ValidationResult struct {
// It specifies maximum lengths for various input types and controls whether
// strict validation mode is enabled.
type InputValidationConfig struct {
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
MaxTokenLength int `json:"max_token_length"`
MaxURLLength int `json:"max_url_length"`
MaxHeaderLength int `json:"max_header_length"`
MaxClaimLength int `json:"max_claim_length"`
MaxEmailLength int `json:"max_email_length"`
MaxUsernameLength int `json:"max_username_length"`
StrictMode bool `json:"strict_mode"`
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
}
// DefaultInputValidationConfig returns a secure default configuration
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
maxHeaderLength: config.MaxHeaderLength,
maxClaimLength: config.MaxClaimLength,
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
}
}
// Check for private IP ranges (RFC 1918)
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
if !iv.allowPrivateIPAddresses {
if strings.HasPrefix(hostname, "10.") ||
strings.HasPrefix(hostname, "192.168.") ||
strings.HasPrefix(hostname, "172.") {
// For 172.x check if it's in the 172.16.0.0/12 range
if strings.HasPrefix(hostname, "172.") {
parts := strings.Split(hostname, ".")
if len(parts) >= 2 {
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
} else {
result.IsValid = false
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
return result
}
}
+8
View File
@@ -177,6 +177,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return "groups" // Backward compatible default
}(),
userIdentifierClaim: func() string {
if config.UserIdentifierClaim != "" {
return config.UserIdentifierClaim
}
return "email" // Backward compatible default
}(),
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
overrideScopes: config.OverrideScopes,
@@ -218,6 +224,8 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
}
// Log audience configuration
+165
View File
@@ -545,3 +545,168 @@ func createTestSessionManager(t *testing.T) *SessionManager {
}
return sm
}
// TestMinimalHeaders tests the minimalHeaders configuration option
// This addresses GitHub issue #64 - Request Header Fields Too Large
func TestMinimalHeaders(t *testing.T) {
tests := []struct {
name string
minimalHeaders bool
expectForwardedUser bool
expectAuthRequestUser bool
expectAuthRequestRedirect bool
}{
{
name: "minimalHeaders=false (default) forwards all headers",
minimalHeaders: false,
expectForwardedUser: true,
expectAuthRequestUser: true,
expectAuthRequestRedirect: true,
},
{
name: "minimalHeaders=true only forwards X-Forwarded-User",
minimalHeaders: true,
expectForwardedUser: true,
expectAuthRequestUser: false,
expectAuthRequestRedirect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Track which headers were set
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
minimalHeaders: tt.minimalHeaders,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
// Create request and get session properly through session manager
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up session data
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Verify X-Forwarded-User is always set
if tt.expectForwardedUser {
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
}
}
// Verify X-Auth-Request-User
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
if tt.expectAuthRequestUser && !hasAuthRequestUser {
t.Error("expected X-Auth-Request-User to be set")
}
if !tt.expectAuthRequestUser && hasAuthRequestUser {
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
}
// Verify X-Auth-Request-Redirect
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
t.Error("expected X-Auth-Request-Redirect to be set")
}
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
}
// Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty.
// Token storage has validation that may reject test tokens, so we verify the flag
// logic through the other headers. The important behavior is that when
// minimalHeaders=true, the token header would NOT be set even if a token existed.
})
}
}
// TestMinimalHeaders_TokenHeaderNotSet verifies that the X-Auth-Request-Token header
// is NOT set when minimalHeaders is enabled, even if a token exists.
func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
minimalHeaders: true, // Enable minimal headers
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Verify X-Forwarded-User is set (always should be)
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
}
// The key verification: X-Auth-Request-Token should NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-Token") != "" {
t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true")
}
// X-Auth-Request-User should also NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-User") != "" {
t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true")
}
// X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-Redirect") != "" {
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
}
}
+243 -19
View File
@@ -122,22 +122,23 @@ func (ts *TestSuite) Setup() {
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
roleClaimName: "roles", // Set default for backward compatibility
groupClaimName: "groups", // Set default for backward compatibility
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
httpClient: &http.Client{Timeout: 10 * time.Second},
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
audience: "test-client-id",
clientSecret: "test-client-secret",
roleClaimName: "roles", // Set default for backward compatibility
groupClaimName: "groups", // Set default for backward compatibility
userIdentifierClaim: "email", // Set default for backward compatibility
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: tokenBlacklist,
tokenCache: tokenCache,
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
httpClient: &http.Client{Timeout: 10 * time.Second},
// Explicitly set paths as New() is bypassed
redirURLPath: "/callback", // Assume default callback path for tests
logoutURLPath: "/callback/logout", // Assume default logout path for tests
@@ -784,7 +785,7 @@ func TestServeHTTP(t *testing.T) {
"Accept": "application/json",
},
expectedStatus: http.StatusForbidden,
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`,
},
{
name: "Disallowed Domain (Accept: HTML)",
@@ -1282,8 +1283,9 @@ func TestHandleCallback(t *testing.T) {
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
}
tOidc := &TraefikOidc{
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
userIdentifierClaim: "email", // Required for claim extraction
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
tokenVerifier: nil, // Will be set to self below
@@ -1438,6 +1440,228 @@ func TestIsAllowedDomain(t *testing.T) {
}
}
func TestIsAllowedUser(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
tests := []struct {
allowedDomains map[string]struct{}
allowedUsers map[string]struct{}
userIdentifierClaim string
name string
userIdentifier string
allowed bool
}{
// Email-based identification (default behavior)
{
name: "Email identifier - allowed domain",
userIdentifier: "user@example.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: true,
},
{
name: "Email identifier - disallowed domain",
userIdentifier: "user@notallowed.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{},
allowed: false,
},
{
name: "Email identifier - specific user allowed",
userIdentifier: "specific.user@otherdomain.com",
userIdentifierClaim: "email",
allowedDomains: map[string]struct{}{"example.com": {}},
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
allowed: true,
},
// Non-email identifier (sub claim - for Azure AD users without email)
{
name: "Sub identifier - allowed in allowedUsers",
userIdentifier: "abc12345-6789-0abc-def0-123456789abc",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
allowed: true,
},
{
name: "Sub identifier - not in allowedUsers",
userIdentifier: "xyz-not-allowed-user",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
allowed: false,
},
{
name: "Sub identifier - allowedDomains ignored for non-email",
userIdentifier: "user-id-12345",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored
allowedUsers: map[string]struct{}{"user-id-12345": {}},
allowed: true,
},
{
name: "Sub identifier - no restrictions allows all",
userIdentifier: "any-user-id",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{},
allowed: true,
},
{
name: "Sub identifier - case insensitive matching",
userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase
allowed: true,
},
// OID claim (Azure AD object ID)
{
name: "OID identifier - allowed user",
userIdentifier: "oid-12345-67890",
userIdentifierClaim: "oid",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"oid-12345-67890": {}},
allowed: true,
},
// UPN claim (Azure AD User Principal Name)
{
name: "UPN identifier - allowed user (looks like email but use sub logic)",
userIdentifier: "user@tenant.onmicrosoft.com",
userIdentifierClaim: "upn",
allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored
allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}},
allowed: true,
},
// Edge cases
{
name: "Empty identifier - not allowed",
userIdentifier: "",
userIdentifierClaim: "sub",
allowedDomains: map[string]struct{}{},
allowedUsers: map[string]struct{}{"some-user": {}},
allowed: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Configure TraefikOidc instance for this test case
tOidc := ts.tOidc
tOidc.allowedUserDomains = tc.allowedDomains
tOidc.allowedUsers = tc.allowedUsers
tOidc.userIdentifierClaim = tc.userIdentifierClaim
allowed := tOidc.isAllowedUser(tc.userIdentifier)
if allowed != tc.allowed {
t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q",
tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim)
}
})
}
}
func TestUserIdentifierClaimExtraction(t *testing.T) {
// Test that the correct claim is extracted based on userIdentifierClaim config
tests := []struct {
name string
userIdentifierClaim string
claims map[string]interface{}
expectedIdentifier string
shouldFallbackToSub bool
}{
{
name: "Email claim extraction (default)",
userIdentifierClaim: "email",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
},
expectedIdentifier: "user@example.com",
shouldFallbackToSub: false,
},
{
name: "Sub claim extraction",
userIdentifierClaim: "sub",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
},
expectedIdentifier: "user-sub-id",
shouldFallbackToSub: false,
},
{
name: "OID claim extraction (Azure AD)",
userIdentifierClaim: "oid",
claims: map[string]interface{}{
"sub": "user-sub-id",
"email": "user@example.com",
"oid": "azure-object-id",
},
expectedIdentifier: "azure-object-id",
shouldFallbackToSub: false,
},
{
name: "UPN claim extraction (Azure AD)",
userIdentifierClaim: "upn",
claims: map[string]interface{}{
"sub": "user-sub-id",
"upn": "user@tenant.onmicrosoft.com",
},
expectedIdentifier: "user@tenant.onmicrosoft.com",
shouldFallbackToSub: false,
},
{
name: "Fallback to sub when configured claim is missing",
userIdentifierClaim: "email",
claims: map[string]interface{}{
"sub": "fallback-sub-id",
// email is missing
},
expectedIdentifier: "fallback-sub-id",
shouldFallbackToSub: true,
},
{
name: "preferred_username claim extraction",
userIdentifierClaim: "preferred_username",
claims: map[string]interface{}{
"sub": "user-sub-id",
"preferred_username": "jdoe",
},
expectedIdentifier: "jdoe",
shouldFallbackToSub: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Extract user identifier using the same logic as auth_flow.go
userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string)
usedFallback := false
if userIdentifier == "" && tc.userIdentifierClaim != "sub" {
userIdentifier, _ = tc.claims["sub"].(string)
usedFallback = true
}
if userIdentifier != tc.expectedIdentifier {
t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier)
}
if usedFallback != tc.shouldFallbackToSub {
t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback)
}
})
}
}
func TestOIDCHandler(t *testing.T) {
ts := NewTestSuite(t)
ts.Setup()
+49 -4
View File
@@ -189,11 +189,56 @@ func (mc *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client
return mc.GetProviderMetadata(ctx, providerURL, httpClient)
}
// GetMetadataWithRecovery fetches metadata with recovery support
// GetMetadataWithRecovery fetches metadata with retry support for startup scenarios.
// This handles the race condition where Traefik initializes the plugin before the
// OIDC provider routes are fully established, or before TLS certificates are loaded.
// Uses aggressive retry settings (10 attempts, 1s intervals) to give the infrastructure
// time to stabilize during cold starts.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
func (mc *MetadataCache) GetMetadataWithRecovery(providerURL string, httpClient *http.Client, logger *Logger, errorRecoveryManager *ErrorRecoveryManager) (*ProviderMetadata, error) {
// For now, just use regular GetMetadata
// Recovery would be handled by ErrorRecoveryManager if needed
return mc.GetMetadata(providerURL, httpClient, logger)
// Check cache first - if we have valid cached metadata, use it
if metadata, exists := mc.Get(providerURL); exists {
return metadata, nil
}
// Create a retry executor with metadata-fetch-specific configuration
retryConfig := MetadataFetchRetryConfig()
retryExecutor := NewRetryExecutor(retryConfig, logger)
var metadata *ProviderMetadata
var lastErr error
// Use context with overall timeout for the entire retry sequence
// 10 attempts * ~10s max delay = ~100s worst case, so use 2 minute timeout
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
err := retryExecutor.ExecuteWithContext(ctx, func() error {
// Create per-attempt context with shorter timeout
attemptCtx, attemptCancel := context.WithTimeout(ctx, 15*time.Second)
defer attemptCancel()
var fetchErr error
metadata, fetchErr = mc.GetProviderMetadata(attemptCtx, providerURL, httpClient)
if fetchErr != nil {
lastErr = fetchErr
if logger != nil {
logger.Debugf("Metadata fetch attempt failed: %v", fetchErr)
}
return fetchErr
}
return nil
})
if err != nil {
// Return the last actual error, not the retry wrapper error
if lastErr != nil {
return nil, lastErr
}
return nil, err
}
return metadata, nil
}
// GetStats returns cache statistics for testing
+17 -14
View File
@@ -125,12 +125,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
email := session.GetEmail()
// Domain restriction check removed debug output
if authenticated && email != "" {
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
@@ -193,10 +193,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
email = session.GetEmail()
if email != "" && !t.isAllowedDomain(email) {
t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
@@ -308,10 +308,13 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
req.Header.Set("X-Forwarded-User", email)
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
}
if len(t.headerTemplates) > 0 {
+11 -4
View File
@@ -41,6 +41,7 @@ type AuthMiddleware struct {
goroutineWG *sync.WaitGroup
startTokenCleanupFunc func()
startMetadataRefreshFunc func(string)
minimalHeaders bool
}
// Logger interface for dependency injection
@@ -120,6 +121,7 @@ func NewAuthMiddleware(
goroutineWG *sync.WaitGroup,
startTokenCleanupFunc func(),
startMetadataRefreshFunc func(string),
minimalHeaders bool,
) *AuthMiddleware {
return &AuthMiddleware{
logger: logger,
@@ -149,6 +151,7 @@ func NewAuthMiddleware(
goroutineWG: goroutineWG,
startTokenCleanupFunc: startTokenCleanupFunc,
startMetadataRefreshFunc: startMetadataRefreshFunc,
minimalHeaders: minimalHeaders,
}
}
@@ -414,10 +417,14 @@ func (m *AuthMiddleware) processAuthorizedRequest(rw http.ResponseWriter, req *h
}
req.Header.Set("X-Forwarded-User", email)
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !m.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
}
m.next.ServeHTTP(rw, req)
@@ -66,6 +66,7 @@ func TestNewAuthMiddleware(t *testing.T) {
wg,
startTokenCleanup,
startMetadataRefresh,
false, // minimalHeaders
)
if m == nil {
+96
View File
@@ -802,3 +802,99 @@ func TestServeHTTP_AdditionalCoverage(t *testing.T) {
}
})
}
// TestProcessAuthorizedRequest_MinimalHeaders tests the minimalHeaders configuration
// This addresses GitHub issue #64 - Request Header Fields Too Large
func TestProcessAuthorizedRequest_MinimalHeaders(t *testing.T) {
tests := []struct {
name string
minimalHeaders bool
expectForwardedUser bool
expectAuthRequestUser bool
expectAuthRequestToken bool
expectAuthRequestRedirect bool
}{
{
name: "minimalHeaders=false forwards all headers",
minimalHeaders: false,
expectForwardedUser: true,
expectAuthRequestUser: true,
expectAuthRequestToken: true,
expectAuthRequestRedirect: true,
},
{
name: "minimalHeaders=true only forwards X-Forwarded-User",
minimalHeaders: true,
expectForwardedUser: true,
expectAuthRequestUser: false,
expectAuthRequestToken: false,
expectAuthRequestRedirect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := &mockLogger{}
var capturedHeaders http.Header
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
session := &mockSessionData{
email: "user@example.com",
idToken: "test-id-token-that-could-be-very-large",
accessToken: "test-access-token",
}
m := &AuthMiddleware{
logger: logger,
next: nextHandler,
minimalHeaders: tt.minimalHeaders,
extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) {
return nil, nil, nil
},
}
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
m.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Verify X-Forwarded-User is always set
if tt.expectForwardedUser {
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
}
}
// Verify X-Auth-Request-User
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
if tt.expectAuthRequestUser && !hasAuthRequestUser {
t.Error("expected X-Auth-Request-User to be set")
}
if !tt.expectAuthRequestUser && hasAuthRequestUser {
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
}
// Verify X-Auth-Request-Token (the big one that causes 431 errors)
hasAuthRequestToken := capturedHeaders.Get("X-Auth-Request-Token") != ""
if tt.expectAuthRequestToken && !hasAuthRequestToken {
t.Error("expected X-Auth-Request-Token to be set")
}
if !tt.expectAuthRequestToken && hasAuthRequestToken {
t.Errorf("expected X-Auth-Request-Token to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Token"))
}
// Verify X-Auth-Request-Redirect
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
t.Error("expected X-Auth-Request-Redirect to be set")
}
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
}
})
}
}
+53
View File
@@ -127,10 +127,63 @@ type Config struct {
// Default: "groups"
GroupClaimName string `json:"groupClaimName,omitempty"`
// UserIdentifierClaim specifies the JWT claim to use as the user identifier.
// This allows authentication for users without email addresses (e.g., Azure AD service accounts).
//
// Examples:
// - Default (backward compatible): "email"
// - Azure AD without email: "sub", "oid", "upn", or "preferred_username"
// - Generic OIDC: "sub" (always present per OIDC spec)
//
// When set to a non-email claim:
// - AllowedUsers will match against this claim value instead of email
// - AllowedUserDomains validation is skipped (domains only apply to email)
// - The session will store this identifier as the user's identity
//
// Default: "email"
UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"`
// DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591)
// When enabled, the middleware will automatically register as a client with
// the OIDC provider if ClientID/ClientSecret are not provided.
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
// AllowPrivateIPAddresses disables the security check that blocks private/internal IP addresses.
// By default, the plugin rejects URLs containing private IP ranges (10.x.x.x, 172.16-31.x.x, 192.168.x.x)
// to prevent SSRF attacks and ensure OIDC providers are publicly accessible.
//
// Enable this option ONLY when:
// - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
// - You have no DNS resolution available for internal services
// - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
//
// Security Warning: Enabling this option reduces SSRF protection. Only use in trusted
// network environments where the OIDC provider is known and controlled.
//
// Default: false (private IPs are blocked for security)
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
// MinimalHeaders reduces the number of headers forwarded to downstream services.
// This helps prevent "431 Request Header Fields Too Large" errors when downstream
// services have limited header buffer sizes.
//
// When enabled (true):
// - Only forwards: X-Forwarded-User
// - Skips: X-Auth-Request-Token (full ID token), X-Auth-Request-Redirect
// - Groups/roles headers (X-User-Groups, X-User-Roles) are still forwarded if configured
// - Custom templated headers are still processed
//
// When disabled (false, default):
// - Forwards all headers: X-Forwarded-User, X-Auth-Request-User, X-Auth-Request-Redirect,
// X-Auth-Request-Token (full ID token)
//
// Use this option when:
// - Downstream services return "431 Request Header Fields Too Large" errors
// - You don't need the full ID token forwarded to backend services
// - You want to reduce request overhead
//
// Default: false (all headers forwarded for backward compatibility)
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
}
// RedisConfig configures Redis cache backend settings for distributed caching.
+3
View File
@@ -99,6 +99,7 @@ type TraefikOidc struct {
audience string // Expected JWT audience, defaults to clientID
roleClaimName string // JWT claim name for extracting roles, defaults to "roles"
groupClaimName string // JWT claim name for extracting groups, defaults to "groups"
userIdentifierClaim string // JWT claim for user identification, defaults to "email"
name string
redirURLPath string
logoutURLPath string
@@ -128,6 +129,8 @@ type TraefikOidc struct {
suppressDiagnosticLogs bool
firstRequestReceived bool
metadataRefreshStarted bool
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
minimalHeaders bool // Reduce headers to prevent 431 errors
securityHeadersApplier func(http.ResponseWriter, *http.Request)
scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering
scopesSupported []string // NEW - from provider metadata
+8 -1
View File
@@ -340,6 +340,7 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
// validateHost validates a hostname or IP address for security.
// It prevents access to localhost, private networks, and known metadata endpoints.
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
// Parameters:
// - host: The host string to validate (may include port).
//
@@ -357,7 +358,13 @@ func (t *TraefikOidc) validateHost(host string) error {
ip := net.ParseIP(hostname)
if ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
// Always block loopback, link-local, and multicast addresses
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return fmt.Errorf("access to loopback/link-local IP addresses is not allowed: %s", ip.String())
}
// Skip private IP check if allowPrivateIPAddresses is enabled
if !t.allowPrivateIPAddresses && ip.IsPrivate() {
return fmt.Errorf("access to private/internal IP addresses is not allowed: %s", ip.String())
}
+45
View File
@@ -55,6 +55,51 @@ func (t *TraefikOidc) safeLogInfo(msg string) {
// DOMAIN VALIDATION
// =============================================================================
// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim.
// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains.
// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers
// since domain-based validation doesn't apply to non-email identifiers.
//
// Parameters:
// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.).
//
// Returns:
// - true if the user is authorized, false otherwise.
func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool {
// If no restrictions are configured, allow all authenticated users
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
return true
}
// Check if user is explicitly allowed
if len(t.allowedUsers) > 0 {
_, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)]
if userAllowed {
t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier)
return true
}
}
// For email-based identifiers, also check domain restrictions
// Only apply domain validation if using email as identifier AND identifier looks like an email
if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") {
return t.isAllowedDomain(userIdentifier)
}
// For non-email identifiers with allowedUserDomains configured, log a warning
if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" {
t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s",
t.userIdentifierClaim, userIdentifier)
}
// User not found in allowedUsers list
if len(t.allowedUsers) > 0 {
t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier)
}
return false
}
// isAllowedDomain checks if an email address is authorized based on domain or user whitelist.
// It validates against both allowed user domains and specific allowed users.
// Parameters: