mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
December 2025 Improvements - Azure AD, Internal Networks, Startup Race Condition (#100)
* Allow internal IPs for OIDC configuration via extra flag. Addresses issue #97 * Allow for internal IPs in OIDC configuration. Addresses issue #97. * feat: Add allowPrivateIPAddresses config option for internal networks Adds a new configuration option `allowPrivateIPAddresses` that allows OIDC provider URLs to use private IP addresses (10.x.x.x, 172.16-31.x.x, 192.168.x.x). This is useful for internal deployments where Keycloak or other OIDC providers run on private networks without DNS resolution. Security considerations: - Loopback addresses (127.0.0.1, localhost, ::1) remain blocked - Link-local addresses (169.254.x.x) remain blocked - Default is false (secure by default) Fixes #97 * feat: Support non-email user identifiers for Azure AD Add userIdentifierClaim configuration option to support Azure AD users without email addresses. This allows using alternative JWT claims like "sub", "oid", "upn", or "preferred_username" for user identification. - Default behavior uses "email" claim (backward compatible) - Falls back to "sub" claim if configured claim is missing - allowedUsers matches against the configured claim value - allowedUserDomains only applies when using email-based identification Fixes #95 * Race condition on traefik pod startup When the plugin initializes and calls GetMetadataWithRecovery(): 1. Checks cache first (if metadata is cached, returns immediately) 2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays) 3. Attempts to fetch metadata from the OIDC provider 4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries 5. After 10 attempts or on a non-retryable error, returns the error This allows the plugin to handle the race condition where: - Traefik initializes the plugin before routes are established - Traefik serves its default certificate before loading real ones - The OIDC provider pod isn't fully ready yet Fixes issue #90 * Race condition on traefik pod startup When the plugin initializes and calls GetMetadataWithRecovery(): 1. Checks cache first (if metadata is cached, returns immediately) 2. Creates a retry executor with startup-optimized settings (10 attempts, 1s delays) 3. Attempts to fetch metadata from the OIDC provider 4. If the fetch fails with a retryable error (connection refused, EOF, TLS/certificate errors, Traefik default cert), it waits and retries 5. After 10 attempts or on a non-retryable error, returns the error This allows the plugin to handle the race condition where: - Traefik initializes the plugin before routes are established - Traefik serves its default certificate before loading real ones - The OIDC provider pod isn't fully ready yet Fixes issue #90 * Headers too big and 431 responses Added new option `minimalHeaders` to reduce the size of forwarded headers from the auth middleware to backend services. - When minimalHeaders: false (default): All headers are forwarded as before - X-Forwarded-User (always set) - X-Auth-Request-Redirect - X-Auth-Request-User - X-Auth-Request-Token (the large ID token) - X-User-Groups, X-User-Roles (if configured) - When minimalHeaders: true: Reduces header overhead - X-Forwarded-User (always set) - X-User-Groups, X-User-Roles (still forwarded if configured) - Custom templated headers (still processed) - Skipped: X-Auth-Request-Token, X-Auth-Request-User, X-Auth-Request-Redirect Fixes issues #64 and #86
This commit is contained in:
+118
@@ -77,6 +77,7 @@ testData:
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
@@ -120,6 +121,8 @@ testData:
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
@@ -266,6 +269,8 @@ testDataWithRedis:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -287,6 +292,26 @@ testDataWithRedis:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -605,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -903,6 +960,67 @@ configuration:
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
|
||||
@@ -124,6 +124,7 @@ The middleware supports the following configuration options:
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
@@ -138,6 +139,8 @@ The middleware supports the following configuration options:
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
@@ -1241,6 +1244,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1327,8 +1369,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1629,12 +1675,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1862,6 +1935,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
+21
-20
@@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
userIdentifierClaim: "email", // Required for user identification
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
+29
-25
@@ -18,17 +18,18 @@ type ScopeFilter interface {
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -40,19 +41,20 @@ type Logger interface {
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
scopeFilter ScopeFilter, scopesSupported []string, allowPrivateIPAddresses bool) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter,
|
||||
scopesSupported: scopesSupported,
|
||||
allowPrivateIPAddresses: allowPrivateIPAddresses,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,6 +349,7 @@ func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
@@ -361,7 +364,7 @@ func (h *Handler) validateHost(host string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
// Check for localhost variations (always blocked, even with allowPrivateIPAddresses)
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
@@ -376,7 +379,8 @@ func (h *Handler) validateHost(host string) error {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
// Skip private IP check if allowPrivateIPAddresses is enabled
|
||||
if !h.allowPrivateIPAddresses && ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
|
||||
+25
-25
@@ -86,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
||||
"test-client-id", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -125,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{redirectCount: 5} // At the limit
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -160,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -191,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -222,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -253,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{saveError: &testError{"save failed"}}
|
||||
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
||||
@@ -297,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
||||
@@ -400,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -468,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -493,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -565,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
tt.scopes, tt.overrideScopes, nil, nil)
|
||||
tt.scopes, tt.overrideScopes, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -676,7 +676,7 @@ func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -714,7 +714,7 @@ func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
|
||||
"https://gitlab.example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -756,7 +756,7 @@ func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://accounts.google.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -797,7 +797,7 @@ func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -831,7 +831,7 @@ func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"generic-client", "https://auth.provider.com/authorize",
|
||||
"https://auth.provider.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -870,7 +870,7 @@ func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, true, scopeFilter, scopesSupported)
|
||||
scopes, true, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -955,7 +955,7 @@ func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, scopesSupported) // scopeFilter is nil
|
||||
scopes, false, nil, scopesSupported, false) // scopeFilter is nil
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -988,7 +988,7 @@ func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -1021,7 +1021,7 @@ func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -1064,7 +1064,7 @@ func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
|
||||
|
||||
@@ -1130,7 +1130,7 @@ func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
|
||||
+103
-5
@@ -10,7 +10,7 @@ import (
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -560,3 +560,101 @@ func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost_AllowPrivateIPAddresses tests the allowPrivateIPAddresses flag
|
||||
func TestAuthHandler_validateHost_AllowPrivateIPAddresses(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Test with allowPrivateIPAddresses = false (default)
|
||||
t.Run("Private IPs blocked by default", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err == nil {
|
||||
t.Errorf("Expected private IP %s to be blocked, but it was allowed", ip)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "private IP not allowed") {
|
||||
t.Errorf("Expected 'private IP not allowed' error for %s, got: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test with allowPrivateIPAddresses = true
|
||||
t.Run("Private IPs allowed when flag enabled", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected private IP %s to be allowed with flag enabled, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that loopback is still blocked even with flag enabled
|
||||
t.Run("Loopback always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
loopbackAddresses := []string{
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
}
|
||||
|
||||
for _, addr := range loopbackAddresses {
|
||||
err := handler.validateHost(addr)
|
||||
if err == nil {
|
||||
t.Errorf("Expected loopback address %s to be blocked even with allowPrivateIPAddresses=true", addr)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that link-local is still blocked even with flag enabled
|
||||
t.Run("Link-local always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
err := handler.validateHost("169.254.1.1")
|
||||
if err == nil {
|
||||
t.Error("Expected link-local address to be blocked even with allowPrivateIPAddresses=true")
|
||||
}
|
||||
})
|
||||
|
||||
// Test that public IPs work with flag enabled
|
||||
t.Run("Public IPs allowed", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"142.250.185.68",
|
||||
}
|
||||
|
||||
for _, ip := range publicIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected public IP %s to be allowed, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+19
-9
@@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -437,6 +437,21 @@ http:
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
If your Keycloak instance runs on an internal network with private IP addresses (e.g., Docker networks, Kubernetes internal services), set `allowPrivateIPAddresses: true`:
|
||||
|
||||
```yaml
|
||||
traefikoidc:
|
||||
providerUrl: "https://192.168.1.100:8443/auth/realms/your-realm" # Private IP
|
||||
allowPrivateIPAddresses: true # Required for private IP addresses
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
# ... other config
|
||||
```
|
||||
|
||||
> **Security Warning**: Only enable `allowPrivateIPAddresses` in trusted network environments where you control the OIDC provider. This setting reduces SSRF protection.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -1088,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -846,3 +846,296 @@ func (e *mockNetError) Temporary() bool { return e.temporary }
|
||||
|
||||
// Ensure mockNetError implements net.Error
|
||||
var _ net.Error = (*mockNetError)(nil)
|
||||
|
||||
// Test isTraefikDefaultCertError
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
|
||||
func TestIsTraefikDefaultCertError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "network error",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isTraefikDefaultCertError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTraefikDefaultCertError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isEOFError
|
||||
|
||||
func TestIsEOFError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing EOF in message",
|
||||
err: errors.New("connection closed: EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing unexpected EOF",
|
||||
err: errors.New("read: unexpected EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "network error without EOF",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isEOFError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isEOFError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isCertificateError
|
||||
|
||||
func TestIsCertificateError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing certificate in message",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing x509 in message",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing tls in message",
|
||||
err: errors.New("tls handshake failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing ssl in message",
|
||||
err: errors.New("ssl connection error"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCertificateError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCertificateError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test MetadataFetchRetryConfig
|
||||
|
||||
func TestMetadataFetchRetryConfig(t *testing.T) {
|
||||
config := MetadataFetchRetryConfig()
|
||||
|
||||
if config.MaxAttempts != 10 {
|
||||
t.Errorf("Expected MaxAttempts 10, got %d", config.MaxAttempts)
|
||||
}
|
||||
|
||||
if config.InitialDelay != 1*time.Second {
|
||||
t.Errorf("Expected InitialDelay 1s, got %v", config.InitialDelay)
|
||||
}
|
||||
|
||||
if config.MaxDelay != 10*time.Second {
|
||||
t.Errorf("Expected MaxDelay 10s, got %v", config.MaxDelay)
|
||||
}
|
||||
|
||||
if config.BackoffFactor != 1.5 {
|
||||
t.Errorf("Expected BackoffFactor 1.5, got %v", config.BackoffFactor)
|
||||
}
|
||||
|
||||
if !config.EnableJitter {
|
||||
t.Error("Expected EnableJitter to be true")
|
||||
}
|
||||
|
||||
// Verify retryable errors include startup-related patterns
|
||||
expectedPatterns := []string{"EOF", "certificate", "x509", "tls"}
|
||||
for _, pattern := range expectedPatterns {
|
||||
found := false
|
||||
for _, retryableErr := range config.RetryableErrors {
|
||||
if retryableErr == pattern {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected '%s' in RetryableErrors", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test RetryExecutor with startup-specific errors
|
||||
|
||||
func TestRetryExecutorStartupErrors(t *testing.T) {
|
||||
// Verify MetadataFetchRetryConfig creates a valid retry executor
|
||||
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF error",
|
||||
err: errors.New("read tcp: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "unexpected EOF",
|
||||
err: errors.New("http: unexpected EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate error",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS error",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
err: errors.New("dial tcp: connection refused"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "permanent error",
|
||||
err: errors.New("invalid response format"),
|
||||
shouldRetry: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use very short delays for testing
|
||||
testConfig := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 1.5,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
testRe := NewRetryExecutor(testConfig, nil)
|
||||
|
||||
attempts := 0
|
||||
_ = testRe.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return tt.err
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts for '%s', got %d", expectedAttempts, tt.name, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that retry executor properly uses isRetryableError with new error types
|
||||
|
||||
func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
|
||||
re := NewRetryExecutor(DefaultRetryConfig(), nil)
|
||||
|
||||
// Test that the enhanced isRetryableError is being used
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF in error message",
|
||||
err: errors.New("connection reset by peer: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate in error message",
|
||||
err: errors.New("x509: certificate has expired"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS in error message",
|
||||
err: errors.New("tls: handshake failure"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := re.isRetryableError(tt.err)
|
||||
if result != tt.shouldRetry {
|
||||
t.Errorf("isRetryableError(%q) = %v, expected %v", tt.err.Error(), result, tt.shouldRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+29
-12
@@ -15,7 +15,8 @@ type OAuthHandler struct {
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
isAllowedUserFunc func(userIdentifier string) bool // validates user authorization
|
||||
userIdentifierClaim string // JWT claim to use for user identification
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
@@ -77,16 +78,22 @@ type TokenResponse struct {
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
// Default to "email" for backward compatibility
|
||||
if userIdentifierClaim == "" {
|
||||
userIdentifierClaim = "email"
|
||||
}
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
isAllowedUserFunc: isAllowedUserFunc,
|
||||
userIdentifierClaim: userIdentifierClaim,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
@@ -225,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[h.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if h.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim)
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !h.isAllowedUserFunc(userIdentifier) {
|
||||
h.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -242,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
isAllowedUser := func(userIdentifier string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowedUser, "email", "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
@@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User identifier missing in token") {
|
||||
t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User not authorized") {
|
||||
t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -177,6 +177,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return "groups" // Backward compatible default
|
||||
}(),
|
||||
userIdentifierClaim: func() string {
|
||||
if config.UserIdentifierClaim != "" {
|
||||
return config.UserIdentifierClaim
|
||||
}
|
||||
return "email" // Backward compatible default
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
@@ -218,6 +224,8 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
|
||||
@@ -545,3 +545,168 @@ func createTestSessionManager(t *testing.T) *SessionManager {
|
||||
}
|
||||
return sm
|
||||
}
|
||||
|
||||
// TestMinimalHeaders tests the minimalHeaders configuration option
|
||||
// This addresses GitHub issue #64 - Request Header Fields Too Large
|
||||
func TestMinimalHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minimalHeaders bool
|
||||
expectForwardedUser bool
|
||||
expectAuthRequestUser bool
|
||||
expectAuthRequestRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "minimalHeaders=false (default) forwards all headers",
|
||||
minimalHeaders: false,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: true,
|
||||
expectAuthRequestRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "minimalHeaders=true only forwards X-Forwarded-User",
|
||||
minimalHeaders: true,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: false,
|
||||
expectAuthRequestRedirect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Track which headers were set
|
||||
var capturedHeaders http.Header
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create request and get session properly through session manager
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is always set
|
||||
if tt.expectForwardedUser {
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-User
|
||||
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
|
||||
if tt.expectAuthRequestUser && !hasAuthRequestUser {
|
||||
t.Error("expected X-Auth-Request-User to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestUser && hasAuthRequestUser {
|
||||
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Redirect
|
||||
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
|
||||
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
|
||||
t.Error("expected X-Auth-Request-Redirect to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
|
||||
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
|
||||
}
|
||||
|
||||
// Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty.
|
||||
// Token storage has validation that may reject test tokens, so we verify the flag
|
||||
// logic through the other headers. The important behavior is that when
|
||||
// minimalHeaders=true, the token header would NOT be set even if a token existed.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMinimalHeaders_TokenHeaderNotSet verifies that the X-Auth-Request-Token header
|
||||
// is NOT set when minimalHeaders is enabled, even if a token exists.
|
||||
func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
var capturedHeaders http.Header
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is set (always should be)
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
|
||||
// The key verification: X-Auth-Request-Token should NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-Token") != "" {
|
||||
t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
|
||||
// X-Auth-Request-User should also NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-User") != "" {
|
||||
t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
|
||||
// X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-Redirect") != "" {
|
||||
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
}
|
||||
|
||||
+243
-19
@@ -122,22 +122,23 @@ func (ts *TestSuite) Setup() {
|
||||
|
||||
// Common TraefikOidc instance
|
||||
ts.tOidc = &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles", // Set default for backward compatibility
|
||||
groupClaimName: "groups", // Set default for backward compatibility
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
revocationURL: "https://revocation-endpoint.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles", // Set default for backward compatibility
|
||||
groupClaimName: "groups", // Set default for backward compatibility
|
||||
userIdentifierClaim: "email", // Set default for backward compatibility
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
revocationURL: "https://revocation-endpoint.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
// Explicitly set paths as New() is bypassed
|
||||
redirURLPath: "/callback", // Assume default callback path for tests
|
||||
logoutURLPath: "/callback/logout", // Assume default logout path for tests
|
||||
@@ -784,7 +785,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`,
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: HTML)",
|
||||
@@ -1282,8 +1283,9 @@ func TestHandleCallback(t *testing.T) {
|
||||
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
userIdentifierClaim: "email", // Required for claim extraction
|
||||
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
|
||||
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
|
||||
tokenVerifier: nil, // Will be set to self below
|
||||
@@ -1438,6 +1440,228 @@ func TestIsAllowedDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedUser(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
allowedDomains map[string]struct{}
|
||||
allowedUsers map[string]struct{}
|
||||
userIdentifierClaim string
|
||||
name string
|
||||
userIdentifier string
|
||||
allowed bool
|
||||
}{
|
||||
// Email-based identification (default behavior)
|
||||
{
|
||||
name: "Email identifier - allowed domain",
|
||||
userIdentifier: "user@example.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Email identifier - disallowed domain",
|
||||
userIdentifier: "user@notallowed.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "Email identifier - specific user allowed",
|
||||
userIdentifier: "specific.user@otherdomain.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// Non-email identifier (sub claim - for Azure AD users without email)
|
||||
{
|
||||
name: "Sub identifier - allowed in allowedUsers",
|
||||
userIdentifier: "abc12345-6789-0abc-def0-123456789abc",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - not in allowedUsers",
|
||||
userIdentifier: "xyz-not-allowed-user",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - allowedDomains ignored for non-email",
|
||||
userIdentifier: "user-id-12345",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored
|
||||
allowedUsers: map[string]struct{}{"user-id-12345": {}},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - no restrictions allows all",
|
||||
userIdentifier: "any-user-id",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - case insensitive matching",
|
||||
userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// OID claim (Azure AD object ID)
|
||||
{
|
||||
name: "OID identifier - allowed user",
|
||||
userIdentifier: "oid-12345-67890",
|
||||
userIdentifierClaim: "oid",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"oid-12345-67890": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// UPN claim (Azure AD User Principal Name)
|
||||
{
|
||||
name: "UPN identifier - allowed user (looks like email but use sub logic)",
|
||||
userIdentifier: "user@tenant.onmicrosoft.com",
|
||||
userIdentifierClaim: "upn",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored
|
||||
allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "Empty identifier - not allowed",
|
||||
userIdentifier: "",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"some-user": {}},
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Configure TraefikOidc instance for this test case
|
||||
tOidc := ts.tOidc
|
||||
tOidc.allowedUserDomains = tc.allowedDomains
|
||||
tOidc.allowedUsers = tc.allowedUsers
|
||||
tOidc.userIdentifierClaim = tc.userIdentifierClaim
|
||||
|
||||
allowed := tOidc.isAllowedUser(tc.userIdentifier)
|
||||
if allowed != tc.allowed {
|
||||
t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q",
|
||||
tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIdentifierClaimExtraction(t *testing.T) {
|
||||
// Test that the correct claim is extracted based on userIdentifierClaim config
|
||||
tests := []struct {
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
claims map[string]interface{}
|
||||
expectedIdentifier string
|
||||
shouldFallbackToSub bool
|
||||
}{
|
||||
{
|
||||
name: "Email claim extraction (default)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedIdentifier: "user@example.com",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "Sub claim extraction",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedIdentifier: "user-sub-id",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "OID claim extraction (Azure AD)",
|
||||
userIdentifierClaim: "oid",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "azure-object-id",
|
||||
},
|
||||
expectedIdentifier: "azure-object-id",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "UPN claim extraction (Azure AD)",
|
||||
userIdentifierClaim: "upn",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"upn": "user@tenant.onmicrosoft.com",
|
||||
},
|
||||
expectedIdentifier: "user@tenant.onmicrosoft.com",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "Fallback to sub when configured claim is missing",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "fallback-sub-id",
|
||||
// email is missing
|
||||
},
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
shouldFallbackToSub: true,
|
||||
},
|
||||
{
|
||||
name: "preferred_username claim extraction",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"preferred_username": "jdoe",
|
||||
},
|
||||
expectedIdentifier: "jdoe",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Extract user identifier using the same logic as auth_flow.go
|
||||
userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string)
|
||||
usedFallback := false
|
||||
|
||||
if userIdentifier == "" && tc.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = tc.claims["sub"].(string)
|
||||
usedFallback = true
|
||||
}
|
||||
|
||||
if userIdentifier != tc.expectedIdentifier {
|
||||
t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier)
|
||||
}
|
||||
|
||||
if usedFallback != tc.shouldFallbackToSub {
|
||||
t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCHandler(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
+49
-4
@@ -189,11 +189,56 @@ func (mc *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client
|
||||
return mc.GetProviderMetadata(ctx, providerURL, httpClient)
|
||||
}
|
||||
|
||||
// GetMetadataWithRecovery fetches metadata with recovery support
|
||||
// GetMetadataWithRecovery fetches metadata with retry support for startup scenarios.
|
||||
// This handles the race condition where Traefik initializes the plugin before the
|
||||
// OIDC provider routes are fully established, or before TLS certificates are loaded.
|
||||
// Uses aggressive retry settings (10 attempts, 1s intervals) to give the infrastructure
|
||||
// time to stabilize during cold starts.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func (mc *MetadataCache) GetMetadataWithRecovery(providerURL string, httpClient *http.Client, logger *Logger, errorRecoveryManager *ErrorRecoveryManager) (*ProviderMetadata, error) {
|
||||
// For now, just use regular GetMetadata
|
||||
// Recovery would be handled by ErrorRecoveryManager if needed
|
||||
return mc.GetMetadata(providerURL, httpClient, logger)
|
||||
// Check cache first - if we have valid cached metadata, use it
|
||||
if metadata, exists := mc.Get(providerURL); exists {
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// Create a retry executor with metadata-fetch-specific configuration
|
||||
retryConfig := MetadataFetchRetryConfig()
|
||||
retryExecutor := NewRetryExecutor(retryConfig, logger)
|
||||
|
||||
var metadata *ProviderMetadata
|
||||
var lastErr error
|
||||
|
||||
// Use context with overall timeout for the entire retry sequence
|
||||
// 10 attempts * ~10s max delay = ~100s worst case, so use 2 minute timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
err := retryExecutor.ExecuteWithContext(ctx, func() error {
|
||||
// Create per-attempt context with shorter timeout
|
||||
attemptCtx, attemptCancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer attemptCancel()
|
||||
|
||||
var fetchErr error
|
||||
metadata, fetchErr = mc.GetProviderMetadata(attemptCtx, providerURL, httpClient)
|
||||
if fetchErr != nil {
|
||||
lastErr = fetchErr
|
||||
if logger != nil {
|
||||
logger.Debugf("Metadata fetch attempt failed: %v", fetchErr)
|
||||
}
|
||||
return fetchErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Return the last actual error, not the retry wrapper error
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics for testing
|
||||
|
||||
+17
-14
@@ -125,12 +125,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
email := session.GetEmail()
|
||||
// Domain restriction check removed debug output
|
||||
if authenticated && email != "" {
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -193,10 +193,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
email = session.GetEmail()
|
||||
if email != "" && !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email)
|
||||
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -308,10 +308,13 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
|
||||
@@ -41,6 +41,7 @@ type AuthMiddleware struct {
|
||||
goroutineWG *sync.WaitGroup
|
||||
startTokenCleanupFunc func()
|
||||
startMetadataRefreshFunc func(string)
|
||||
minimalHeaders bool
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -120,6 +121,7 @@ func NewAuthMiddleware(
|
||||
goroutineWG *sync.WaitGroup,
|
||||
startTokenCleanupFunc func(),
|
||||
startMetadataRefreshFunc func(string),
|
||||
minimalHeaders bool,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
logger: logger,
|
||||
@@ -149,6 +151,7 @@ func NewAuthMiddleware(
|
||||
goroutineWG: goroutineWG,
|
||||
startTokenCleanupFunc: startTokenCleanupFunc,
|
||||
startMetadataRefreshFunc: startMetadataRefreshFunc,
|
||||
minimalHeaders: minimalHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,10 +417,14 @@ func (m *AuthMiddleware) processAuthorizedRequest(rw http.ResponseWriter, req *h
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !m.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
m.next.ServeHTTP(rw, req)
|
||||
|
||||
@@ -66,6 +66,7 @@ func TestNewAuthMiddleware(t *testing.T) {
|
||||
wg,
|
||||
startTokenCleanup,
|
||||
startMetadataRefresh,
|
||||
false, // minimalHeaders
|
||||
)
|
||||
|
||||
if m == nil {
|
||||
|
||||
@@ -802,3 +802,99 @@ func TestServeHTTP_AdditionalCoverage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessAuthorizedRequest_MinimalHeaders tests the minimalHeaders configuration
|
||||
// This addresses GitHub issue #64 - Request Header Fields Too Large
|
||||
func TestProcessAuthorizedRequest_MinimalHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minimalHeaders bool
|
||||
expectForwardedUser bool
|
||||
expectAuthRequestUser bool
|
||||
expectAuthRequestToken bool
|
||||
expectAuthRequestRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "minimalHeaders=false forwards all headers",
|
||||
minimalHeaders: false,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: true,
|
||||
expectAuthRequestToken: true,
|
||||
expectAuthRequestRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "minimalHeaders=true only forwards X-Forwarded-User",
|
||||
minimalHeaders: true,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: false,
|
||||
expectAuthRequestToken: false,
|
||||
expectAuthRequestRedirect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
var capturedHeaders http.Header
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
session := &mockSessionData{
|
||||
email: "user@example.com",
|
||||
idToken: "test-id-token-that-could-be-very-large",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
|
||||
m := &AuthMiddleware{
|
||||
logger: logger,
|
||||
next: nextHandler,
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) {
|
||||
return nil, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
m.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is always set
|
||||
if tt.expectForwardedUser {
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-User
|
||||
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
|
||||
if tt.expectAuthRequestUser && !hasAuthRequestUser {
|
||||
t.Error("expected X-Auth-Request-User to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestUser && hasAuthRequestUser {
|
||||
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Token (the big one that causes 431 errors)
|
||||
hasAuthRequestToken := capturedHeaders.Get("X-Auth-Request-Token") != ""
|
||||
if tt.expectAuthRequestToken && !hasAuthRequestToken {
|
||||
t.Error("expected X-Auth-Request-Token to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestToken && hasAuthRequestToken {
|
||||
t.Errorf("expected X-Auth-Request-Token to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Token"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Redirect
|
||||
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
|
||||
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
|
||||
t.Error("expected X-Auth-Request-Redirect to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
|
||||
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+53
@@ -127,10 +127,63 @@ type Config struct {
|
||||
// Default: "groups"
|
||||
GroupClaimName string `json:"groupClaimName,omitempty"`
|
||||
|
||||
// UserIdentifierClaim specifies the JWT claim to use as the user identifier.
|
||||
// This allows authentication for users without email addresses (e.g., Azure AD service accounts).
|
||||
//
|
||||
// Examples:
|
||||
// - Default (backward compatible): "email"
|
||||
// - Azure AD without email: "sub", "oid", "upn", or "preferred_username"
|
||||
// - Generic OIDC: "sub" (always present per OIDC spec)
|
||||
//
|
||||
// When set to a non-email claim:
|
||||
// - AllowedUsers will match against this claim value instead of email
|
||||
// - AllowedUserDomains validation is skipped (domains only apply to email)
|
||||
// - The session will store this identifier as the user's identity
|
||||
//
|
||||
// Default: "email"
|
||||
UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"`
|
||||
|
||||
// DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591)
|
||||
// When enabled, the middleware will automatically register as a client with
|
||||
// the OIDC provider if ClientID/ClientSecret are not provided.
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
|
||||
// AllowPrivateIPAddresses disables the security check that blocks private/internal IP addresses.
|
||||
// By default, the plugin rejects URLs containing private IP ranges (10.x.x.x, 172.16-31.x.x, 192.168.x.x)
|
||||
// to prevent SSRF attacks and ensure OIDC providers are publicly accessible.
|
||||
//
|
||||
// Enable this option ONLY when:
|
||||
// - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
// - You have no DNS resolution available for internal services
|
||||
// - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
//
|
||||
// Security Warning: Enabling this option reduces SSRF protection. Only use in trusted
|
||||
// network environments where the OIDC provider is known and controlled.
|
||||
//
|
||||
// Default: false (private IPs are blocked for security)
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
|
||||
// MinimalHeaders reduces the number of headers forwarded to downstream services.
|
||||
// This helps prevent "431 Request Header Fields Too Large" errors when downstream
|
||||
// services have limited header buffer sizes.
|
||||
//
|
||||
// When enabled (true):
|
||||
// - Only forwards: X-Forwarded-User
|
||||
// - Skips: X-Auth-Request-Token (full ID token), X-Auth-Request-Redirect
|
||||
// - Groups/roles headers (X-User-Groups, X-User-Roles) are still forwarded if configured
|
||||
// - Custom templated headers are still processed
|
||||
//
|
||||
// When disabled (false, default):
|
||||
// - Forwards all headers: X-Forwarded-User, X-Auth-Request-User, X-Auth-Request-Redirect,
|
||||
// X-Auth-Request-Token (full ID token)
|
||||
//
|
||||
// Use this option when:
|
||||
// - Downstream services return "431 Request Header Fields Too Large" errors
|
||||
// - You don't need the full ID token forwarded to backend services
|
||||
// - You want to reduce request overhead
|
||||
//
|
||||
// Default: false (all headers forwarded for backward compatibility)
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
|
||||
@@ -99,6 +99,7 @@ type TraefikOidc struct {
|
||||
audience string // Expected JWT audience, defaults to clientID
|
||||
roleClaimName string // JWT claim name for extracting roles, defaults to "roles"
|
||||
groupClaimName string // JWT claim name for extracting groups, defaults to "groups"
|
||||
userIdentifierClaim string // JWT claim for user identification, defaults to "email"
|
||||
name string
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
@@ -128,6 +129,8 @@ type TraefikOidc struct {
|
||||
suppressDiagnosticLogs bool
|
||||
firstRequestReceived bool
|
||||
metadataRefreshStarted bool
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
|
||||
minimalHeaders bool // Reduce headers to prevent 431 errors
|
||||
securityHeadersApplier func(http.ResponseWriter, *http.Request)
|
||||
scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
|
||||
+8
-1
@@ -340,6 +340,7 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname or IP address for security.
|
||||
// It prevents access to localhost, private networks, and known metadata endpoints.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
// Parameters:
|
||||
// - host: The host string to validate (may include port).
|
||||
//
|
||||
@@ -357,7 +358,13 @@ func (t *TraefikOidc) validateHost(host string) error {
|
||||
|
||||
ip := net.ParseIP(hostname)
|
||||
if ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
// Always block loopback, link-local, and multicast addresses
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return fmt.Errorf("access to loopback/link-local IP addresses is not allowed: %s", ip.String())
|
||||
}
|
||||
|
||||
// Skip private IP check if allowPrivateIPAddresses is enabled
|
||||
if !t.allowPrivateIPAddresses && ip.IsPrivate() {
|
||||
return fmt.Errorf("access to private/internal IP addresses is not allowed: %s", ip.String())
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,51 @@ func (t *TraefikOidc) safeLogInfo(msg string) {
|
||||
// DOMAIN VALIDATION
|
||||
// =============================================================================
|
||||
|
||||
// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim.
|
||||
// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains.
|
||||
// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers
|
||||
// since domain-based validation doesn't apply to non-email identifiers.
|
||||
//
|
||||
// Parameters:
|
||||
// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.).
|
||||
//
|
||||
// Returns:
|
||||
// - true if the user is authorized, false otherwise.
|
||||
func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool {
|
||||
// If no restrictions are configured, allow all authenticated users
|
||||
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if user is explicitly allowed
|
||||
if len(t.allowedUsers) > 0 {
|
||||
_, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)]
|
||||
if userAllowed {
|
||||
t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// For email-based identifiers, also check domain restrictions
|
||||
// Only apply domain validation if using email as identifier AND identifier looks like an email
|
||||
if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") {
|
||||
return t.isAllowedDomain(userIdentifier)
|
||||
}
|
||||
|
||||
// For non-email identifiers with allowedUserDomains configured, log a warning
|
||||
if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" {
|
||||
t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s",
|
||||
t.userIdentifierClaim, userIdentifier)
|
||||
}
|
||||
|
||||
// User not found in allowedUsers list
|
||||
if len(t.allowedUsers) > 0 {
|
||||
t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isAllowedDomain checks if an email address is authorized based on domain or user whitelist.
|
||||
// It validates against both allowed user domains and specific allowed users.
|
||||
// Parameters:
|
||||
|
||||
Reference in New Issue
Block a user