diff --git a/.traefik.yml b/.traefik.yml index ddb18cc..091ff6b 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -77,6 +77,7 @@ testData: # Custom claim names for Auth0 and other providers with namespaced claims roleClaimName: roles # JWT claim name for extracting user roles (default: "roles") groupClaimName: groups # JWT claim name for extracting user groups (default: "groups") + userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username") # ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.) # When NOT specified in config: defaults to FALSE (Go zero value) @@ -290,6 +291,26 @@ testDataWithRedis: # - "AppRoleName" # # See README.md "Provider Configuration Recommendations" for Azure AD. +# --- Azure AD Users Without Email Example (Issue #95) --- +# testDataAzureADNoEmail: +# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 +# clientID: your-azure-ad-client-id +# clientSecret: your-azure-ad-client-secret +# callbackURL: /oauth2/callback +# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure" +# # Use 'sub' claim instead of 'email' for user identification +# userIdentifierClaim: sub # or "oid", "upn", "preferred_username" +# overrideScopes: true # Remove email scope if not needed +# scopes: +# - openid +# - profile +# - groups # For group-based access control +# # When using non-email identifiers, allowedUsers matches against the claim value +# allowedUsers: +# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim) +# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email" +# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95 + # --- Google Workspace / Google Cloud Identity Example --- # testDataGoogle: # providerURL: https://accounts.google.com # Standard Google OIDC endpoint @@ -608,6 +629,38 @@ configuration: items: type: string + userIdentifierClaim: + type: string + description: | + Specifies the JWT claim to use as the user identifier for authentication and authorization. + + This allows authentication for users without email addresses, such as Azure AD service + accounts or organizational accounts that don't have email attributes configured. + + When set to a non-email claim (e.g., "sub", "oid", "upn"): + - AllowedUsers will match against this claim value instead of email + - AllowedUserDomains validation is skipped (domains only apply to email addresses) + - The session stores this identifier as the user's identity + - If the configured claim is missing, falls back to "sub" (required by OIDC spec) + + Common values by provider: + - Default: "email" (standard email-based identification) + - Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username" + - Generic OIDC: "sub" (always present per OIDC specification) + - Keycloak: "sub", "preferred_username" + + Example for Azure AD users without email: + ```yaml + userIdentifierClaim: sub + allowedUsers: + - "abc123-user-object-id" + - "xyz789-another-user-id" + ``` + + Default: "email" + See: https://github.com/lukaszraczylo/traefikoidc/issues/95 + required: false + revocationURL: type: string description: | diff --git a/README.md b/README.md index e369dc6..82ac10f 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ The middleware supports the following configuration options: | `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` | | `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` | | `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` | +| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` | | `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` | | `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` | | `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` | @@ -1242,6 +1243,45 @@ spec: - "AppRoleName" # Application role names ``` +### Azure AD Configuration (Users Without Email) + +For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes): + +```yaml +apiVersion: traefik.io/v1alpha1 +kind: Middleware +metadata: + name: oidc-azure-no-email + namespace: traefik +spec: + plugin: + traefikoidc: + providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0 + clientID: your-azure-ad-client-id + clientSecret: your-azure-ad-client-secret + sessionEncryptionKey: your-secure-encryption-key-min-32-chars + callbackURL: /oauth2/callback + logoutURL: /oauth2/logout + + # Use 'sub' instead of 'email' for user identification + userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username" + + overrideScopes: true # Optional: Don't request email scope if not needed + scopes: + - openid + - profile + - groups + + # When using non-email identifiers, allowedUsers matches against the claim value + allowedUsers: + - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID + - "def67890-1234-5678-90ab-cdef12345678" + + # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email" +``` + +> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead. + ### Auth0 Configuration ```yaml diff --git a/audience_validation_test.go b/audience_validation_test.go index 8e07184..7fe7aa7 100644 --- a/audience_validation_test.go +++ b/audience_validation_test.go @@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) { customAudience := "https://api.company.com" tOidc := &TraefikOidc{ - next: nextHandler, - name: "test", - redirURLPath: "/callback", - logoutURLPath: "/callback/logout", - issuerURL: "https://auth.company.com", - clientID: "test-client-id", - clientSecret: "test-client-secret", - audience: customAudience, // Set custom audience - jwkCache: mockJWKCache, - jwksURL: "https://auth.company.com/.well-known/jwks.json", - tokenBlacklist: tokenBlacklist, - tokenCache: tokenCache, - limiter: rate.NewLimiter(rate.Every(time.Second), 10), - logger: logger, - allowedUserDomains: map[string]struct{}{"company.com": {}}, - excludedURLs: map[string]struct{}{}, - httpClient: &http.Client{}, - initComplete: make(chan struct{}), - sessionManager: sm, - extractClaimsFunc: extractClaims, + next: nextHandler, + name: "test", + redirURLPath: "/callback", + logoutURLPath: "/callback/logout", + issuerURL: "https://auth.company.com", + clientID: "test-client-id", + clientSecret: "test-client-secret", + audience: customAudience, // Set custom audience + jwkCache: mockJWKCache, + jwksURL: "https://auth.company.com/.well-known/jwks.json", + tokenBlacklist: tokenBlacklist, + tokenCache: tokenCache, + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + logger: logger, + allowedUserDomains: map[string]struct{}{"company.com": {}}, + userIdentifierClaim: "email", // Required for user identification + excludedURLs: map[string]struct{}{}, + httpClient: &http.Client{}, + initComplete: make(chan struct{}), + sessionManager: sm, + extractClaimsFunc: extractClaims, } tOidc.jwtVerifier = tOidc tOidc.tokenVerifier = tOidc diff --git a/auth_flow.go b/auth_flow.go index 5b0cb9f..505c532 100644 --- a/auth_flow.go +++ b/auth_flow.go @@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, return } - email, _ := claims["email"].(string) - if email == "" { - t.logger.Errorf("Email claim missing or empty in token during callback") - t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError) - return + // Extract user identifier from the configured claim (defaults to "email" for backward compatibility) + userIdentifier, _ := claims[t.userIdentifierClaim].(string) + if userIdentifier == "" { + // Try "sub" as fallback since it's required by OIDC spec + if t.userIdentifierClaim != "sub" { + userIdentifier, _ = claims["sub"].(string) + } + if userIdentifier == "" { + t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim) + t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError) + return + } + t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim) } - if !t.isAllowedDomain(email) { - t.logger.Errorf("Disallowed email domain during callback: %s", email) - t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden) + + // Validate user authorization + if !t.isAllowedUser(userIdentifier) { + t.logger.Errorf("User not authorized during callback: %s", userIdentifier) + t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden) return } @@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError) return } - session.SetEmail(email) + session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim) session.SetIDToken(tokenResponse.IDToken) session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) diff --git a/handlers/oauth_handler.go b/handlers/oauth_handler.go index 055d4f6..2a1f1d3 100644 --- a/handlers/oauth_handler.go +++ b/handlers/oauth_handler.go @@ -15,7 +15,8 @@ type OAuthHandler struct { tokenExchanger TokenExchanger tokenVerifier TokenVerifier extractClaimsFunc func(tokenString string) (map[string]interface{}, error) - isAllowedDomainFunc func(email string) bool + isAllowedUserFunc func(userIdentifier string) bool // validates user authorization + userIdentifierClaim string // JWT claim to use for user identification redirURLPath string sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int) } @@ -77,16 +78,22 @@ type TokenResponse struct { // NewOAuthHandler creates a new OAuth handler func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger, tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error), - isAllowedDomainFunc func(string) bool, redirURLPath string, + isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string, sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler { + // Default to "email" for backward compatibility + if userIdentifierClaim == "" { + userIdentifierClaim = "email" + } + return &OAuthHandler{ logger: logger, sessionManager: sessionManager, tokenExchanger: tokenExchanger, tokenVerifier: tokenVerifier, extractClaimsFunc: extractClaimsFunc, - isAllowedDomainFunc: isAllowedDomainFunc, + isAllowedUserFunc: isAllowedUserFunc, + userIdentifierClaim: userIdentifierClaim, redirURLPath: redirURLPath, sendErrorResponseFunc: sendErrorResponseFunc, } @@ -225,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, return } - email, _ := claims["email"].(string) - if email == "" { - h.logger.Errorf("Email claim missing or empty in token during callback") - h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError) - return + // Extract user identifier from the configured claim (defaults to "email" for backward compatibility) + userIdentifier, _ := claims[h.userIdentifierClaim].(string) + if userIdentifier == "" { + // Try "sub" as fallback since it's required by OIDC spec + if h.userIdentifierClaim != "sub" { + userIdentifier, _ = claims["sub"].(string) + } + if userIdentifier == "" { + h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim) + h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError) + return + } + h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim) } - if !h.isAllowedDomainFunc(email) { - h.logger.Errorf("Disallowed email domain during callback: %s", email) - h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden) + + // Validate user authorization + if !h.isAllowedUserFunc(userIdentifier) { + h.logger.Errorf("User not authorized during callback: %s", userIdentifier) + h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden) return } @@ -242,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError) return } - session.SetEmail(email) + session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim) session.SetIDToken(tokenResponse.IDToken) session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) diff --git a/handlers/oauth_handler_test.go b/handlers/oauth_handler_test.go index 2e3c9f0..615ce55 100644 --- a/handlers/oauth_handler_test.go +++ b/handlers/oauth_handler_test.go @@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) { return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil } - isAllowed := func(email string) bool { return true } + isAllowedUser := func(userIdentifier string) bool { return true } sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {} handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowedUser, "email", "/callback", sendError) if handler == nil { t.Fatal("Expected handler to be created, got nil") @@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil) rw := httptest.NewRecorder() @@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) // Test with error parameter req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil) @@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test", nil) rw := httptest.NewRecorder() @@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil) rw := httptest.NewRecorder() @@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil) rw := httptest.NewRecorder() @@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?state=test-state", nil) rw := httptest.NewRecorder() @@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) { if code != http.StatusInternalServerError { t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code) } - if !strings.Contains(msg, "Email missing in token") { - t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg) + if !strings.Contains(msg, "User identifier missing in token") { + t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg) } } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) { if code != http.StatusForbidden { t.Errorf("Expected status %d, got %d", http.StatusForbidden, code) } - if !strings.Contains(msg, "Email domain not allowed") { - t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg) + if !strings.Contains(msg, "User not authorized") { + t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg) } } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() @@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) { } handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier, - extractClaims, isAllowed, "/callback", sendError) + extractClaims, isAllowed, "email", "/callback", sendError) req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil) rw := httptest.NewRecorder() diff --git a/main.go b/main.go index a338774..1ebbd8e 100644 --- a/main.go +++ b/main.go @@ -177,6 +177,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name } return "groups" // Backward compatible default }(), + userIdentifierClaim: func() string { + if config.UserIdentifierClaim != "" { + return config.UserIdentifierClaim + } + return "email" // Backward compatible default + }(), forceHTTPS: config.ForceHTTPS, enablePKCE: config.EnablePKCE, overrideScopes: config.OverrideScopes, diff --git a/main_test.go b/main_test.go index 3da5b2e..57a2d9e 100644 --- a/main_test.go +++ b/main_test.go @@ -122,22 +122,23 @@ func (ts *TestSuite) Setup() { // Common TraefikOidc instance ts.tOidc = &TraefikOidc{ - issuerURL: "https://test-issuer.com", - clientID: "test-client-id", - audience: "test-client-id", - clientSecret: "test-client-secret", - roleClaimName: "roles", // Set default for backward compatibility - groupClaimName: "groups", // Set default for backward compatibility - jwkCache: ts.mockJWKCache, - jwksURL: "https://test-jwks-url.com", - revocationURL: "https://revocation-endpoint.com", - limiter: rate.NewLimiter(rate.Every(time.Second), 10), - tokenBlacklist: tokenBlacklist, - tokenCache: tokenCache, - logger: logger, - allowedUserDomains: map[string]struct{}{"example.com": {}}, - excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}}, - httpClient: &http.Client{Timeout: 10 * time.Second}, + issuerURL: "https://test-issuer.com", + clientID: "test-client-id", + audience: "test-client-id", + clientSecret: "test-client-secret", + roleClaimName: "roles", // Set default for backward compatibility + groupClaimName: "groups", // Set default for backward compatibility + userIdentifierClaim: "email", // Set default for backward compatibility + jwkCache: ts.mockJWKCache, + jwksURL: "https://test-jwks-url.com", + revocationURL: "https://revocation-endpoint.com", + limiter: rate.NewLimiter(rate.Every(time.Second), 10), + tokenBlacklist: tokenBlacklist, + tokenCache: tokenCache, + logger: logger, + allowedUserDomains: map[string]struct{}{"example.com": {}}, + excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}}, + httpClient: &http.Client{Timeout: 10 * time.Second}, // Explicitly set paths as New() is bypassed redirURLPath: "/callback", // Assume default callback path for tests logoutURLPath: "/callback/logout", // Assume default logout path for tests @@ -784,7 +785,7 @@ func TestServeHTTP(t *testing.T) { "Accept": "application/json", }, expectedStatus: http.StatusForbidden, - expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`, + expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`, }, { name: "Disallowed Domain (Accept: HTML)", @@ -1282,8 +1283,9 @@ func TestHandleCallback(t *testing.T) { instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case } tOidc := &TraefikOidc{ - allowedUserDomains: map[string]struct{}{"example.com": {}}, - logger: logger, + allowedUserDomains: map[string]struct{}{"example.com": {}}, + logger: logger, + userIdentifierClaim: "email", // Required for claim extraction // exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function tokenVerifier: nil, // Will be set to self below @@ -1438,6 +1440,228 @@ func TestIsAllowedDomain(t *testing.T) { } } +func TestIsAllowedUser(t *testing.T) { + ts := NewTestSuite(t) + ts.Setup() + + tests := []struct { + allowedDomains map[string]struct{} + allowedUsers map[string]struct{} + userIdentifierClaim string + name string + userIdentifier string + allowed bool + }{ + // Email-based identification (default behavior) + { + name: "Email identifier - allowed domain", + userIdentifier: "user@example.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{}, + allowed: true, + }, + { + name: "Email identifier - disallowed domain", + userIdentifier: "user@notallowed.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{}, + allowed: false, + }, + { + name: "Email identifier - specific user allowed", + userIdentifier: "specific.user@otherdomain.com", + userIdentifierClaim: "email", + allowedDomains: map[string]struct{}{"example.com": {}}, + allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, + allowed: true, + }, + + // Non-email identifier (sub claim - for Azure AD users without email) + { + name: "Sub identifier - allowed in allowedUsers", + userIdentifier: "abc12345-6789-0abc-def0-123456789abc", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, + allowed: true, + }, + { + name: "Sub identifier - not in allowedUsers", + userIdentifier: "xyz-not-allowed-user", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, + allowed: false, + }, + { + name: "Sub identifier - allowedDomains ignored for non-email", + userIdentifier: "user-id-12345", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored + allowedUsers: map[string]struct{}{"user-id-12345": {}}, + allowed: true, + }, + { + name: "Sub identifier - no restrictions allows all", + userIdentifier: "any-user-id", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{}, + allowed: true, + }, + { + name: "Sub identifier - case insensitive matching", + userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase + allowed: true, + }, + + // OID claim (Azure AD object ID) + { + name: "OID identifier - allowed user", + userIdentifier: "oid-12345-67890", + userIdentifierClaim: "oid", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"oid-12345-67890": {}}, + allowed: true, + }, + + // UPN claim (Azure AD User Principal Name) + { + name: "UPN identifier - allowed user (looks like email but use sub logic)", + userIdentifier: "user@tenant.onmicrosoft.com", + userIdentifierClaim: "upn", + allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored + allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}}, + allowed: true, + }, + + // Edge cases + { + name: "Empty identifier - not allowed", + userIdentifier: "", + userIdentifierClaim: "sub", + allowedDomains: map[string]struct{}{}, + allowedUsers: map[string]struct{}{"some-user": {}}, + allowed: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Configure TraefikOidc instance for this test case + tOidc := ts.tOidc + tOidc.allowedUserDomains = tc.allowedDomains + tOidc.allowedUsers = tc.allowedUsers + tOidc.userIdentifierClaim = tc.userIdentifierClaim + + allowed := tOidc.isAllowedUser(tc.userIdentifier) + if allowed != tc.allowed { + t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q", + tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim) + } + }) + } +} + +func TestUserIdentifierClaimExtraction(t *testing.T) { + // Test that the correct claim is extracted based on userIdentifierClaim config + tests := []struct { + name string + userIdentifierClaim string + claims map[string]interface{} + expectedIdentifier string + shouldFallbackToSub bool + }{ + { + name: "Email claim extraction (default)", + userIdentifierClaim: "email", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + }, + expectedIdentifier: "user@example.com", + shouldFallbackToSub: false, + }, + { + name: "Sub claim extraction", + userIdentifierClaim: "sub", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + }, + expectedIdentifier: "user-sub-id", + shouldFallbackToSub: false, + }, + { + name: "OID claim extraction (Azure AD)", + userIdentifierClaim: "oid", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "email": "user@example.com", + "oid": "azure-object-id", + }, + expectedIdentifier: "azure-object-id", + shouldFallbackToSub: false, + }, + { + name: "UPN claim extraction (Azure AD)", + userIdentifierClaim: "upn", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "upn": "user@tenant.onmicrosoft.com", + }, + expectedIdentifier: "user@tenant.onmicrosoft.com", + shouldFallbackToSub: false, + }, + { + name: "Fallback to sub when configured claim is missing", + userIdentifierClaim: "email", + claims: map[string]interface{}{ + "sub": "fallback-sub-id", + // email is missing + }, + expectedIdentifier: "fallback-sub-id", + shouldFallbackToSub: true, + }, + { + name: "preferred_username claim extraction", + userIdentifierClaim: "preferred_username", + claims: map[string]interface{}{ + "sub": "user-sub-id", + "preferred_username": "jdoe", + }, + expectedIdentifier: "jdoe", + shouldFallbackToSub: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Extract user identifier using the same logic as auth_flow.go + userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string) + usedFallback := false + + if userIdentifier == "" && tc.userIdentifierClaim != "sub" { + userIdentifier, _ = tc.claims["sub"].(string) + usedFallback = true + } + + if userIdentifier != tc.expectedIdentifier { + t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier) + } + + if usedFallback != tc.shouldFallbackToSub { + t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback) + } + }) + } +} + func TestOIDCHandler(t *testing.T) { ts := NewTestSuite(t) ts.Setup() diff --git a/middleware.go b/middleware.go index 9d103e4..518ddd9 100644 --- a/middleware.go +++ b/middleware.go @@ -125,12 +125,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - email := session.GetEmail() - // Domain restriction check removed debug output - if authenticated && email != "" { - if !t.isAllowedDomain(email) { - t.logger.Infof("User with email %s is not from an allowed domain", email) - errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim) + // User authorization check + if authenticated && userIdentifier != "" { + if !t.isAllowedUser(userIdentifier) { + t.logger.Infof("User %s is not authorized", userIdentifier) + errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath) t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } @@ -193,10 +193,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { refreshed := t.refreshToken(rw, req, session) if refreshed { - email = session.GetEmail() - if email != "" && !t.isAllowedDomain(email) { - t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email) - errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier + if userIdentifier != "" && !t.isAllowedUser(userIdentifier) { + t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier) + errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath) t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } diff --git a/settings.go b/settings.go index 9a2540d..19671b6 100644 --- a/settings.go +++ b/settings.go @@ -127,6 +127,22 @@ type Config struct { // Default: "groups" GroupClaimName string `json:"groupClaimName,omitempty"` + // UserIdentifierClaim specifies the JWT claim to use as the user identifier. + // This allows authentication for users without email addresses (e.g., Azure AD service accounts). + // + // Examples: + // - Default (backward compatible): "email" + // - Azure AD without email: "sub", "oid", "upn", or "preferred_username" + // - Generic OIDC: "sub" (always present per OIDC spec) + // + // When set to a non-email claim: + // - AllowedUsers will match against this claim value instead of email + // - AllowedUserDomains validation is skipped (domains only apply to email) + // - The session will store this identifier as the user's identity + // + // Default: "email" + UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"` + // DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591) // When enabled, the middleware will automatically register as a client with // the OIDC provider if ClientID/ClientSecret are not provided. diff --git a/types.go b/types.go index ad66337..f4172dc 100644 --- a/types.go +++ b/types.go @@ -99,6 +99,7 @@ type TraefikOidc struct { audience string // Expected JWT audience, defaults to clientID roleClaimName string // JWT claim name for extracting roles, defaults to "roles" groupClaimName string // JWT claim name for extracting groups, defaults to "groups" + userIdentifierClaim string // JWT claim for user identification, defaults to "email" name string redirURLPath string logoutURLPath string diff --git a/utilities.go b/utilities.go index ddfbc57..dce4518 100644 --- a/utilities.go +++ b/utilities.go @@ -55,6 +55,51 @@ func (t *TraefikOidc) safeLogInfo(msg string) { // DOMAIN VALIDATION // ============================================================================= +// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim. +// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains. +// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers +// since domain-based validation doesn't apply to non-email identifiers. +// +// Parameters: +// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.). +// +// Returns: +// - true if the user is authorized, false otherwise. +func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool { + // If no restrictions are configured, allow all authenticated users + if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 { + return true + } + + // Check if user is explicitly allowed + if len(t.allowedUsers) > 0 { + _, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)] + if userAllowed { + t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier) + return true + } + } + + // For email-based identifiers, also check domain restrictions + // Only apply domain validation if using email as identifier AND identifier looks like an email + if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") { + return t.isAllowedDomain(userIdentifier) + } + + // For non-email identifiers with allowedUserDomains configured, log a warning + if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" { + t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s", + t.userIdentifierClaim, userIdentifier) + } + + // User not found in allowedUsers list + if len(t.allowedUsers) > 0 { + t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier) + } + + return false +} + // isAllowedDomain checks if an email address is authorized based on domain or user whitelist. // It validates against both allowed user domains and specific allowed users. // Parameters: