diff --git a/audience_test.go b/audience_test.go index f805ab0..5310df1 100644 --- a/audience_test.go +++ b/audience_test.go @@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) { if err := session.SetAuthenticated(true); err != nil { t.Fatalf("Failed to set authenticated: %v", err) } - session.SetEmail("user@company.com") + session.SetUserIdentifier("user@company.com") session.SetIDToken(validJWT) session.SetAccessToken(validJWT) diff --git a/auth_flow.go b/auth_flow.go index 47b3113..39d577a 100644 --- a/auth_flow.go +++ b/auth_flow.go @@ -43,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) { func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) { // Clear all existing session data _ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow - session.SetEmail("") + session.SetUserIdentifier("") session.SetAccessToken("") session.SetRefreshToken("") session.SetIDToken("") @@ -250,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError) return } - session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim) + session.SetUserIdentifier(userIdentifier) session.SetIDToken(tokenResponse.IDToken) session.SetAccessToken(tokenResponse.AccessToken) session.SetRefreshToken(tokenResponse.RefreshToken) @@ -290,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque session.SetIDToken("") session.SetAccessToken("") session.SetRefreshToken("") - session.SetEmail("") + session.SetUserIdentifier("") // Clear CSRF tokens to prevent replay attacks session.SetCSRF("") session.SetNonce("") diff --git a/auth_flow_behaviour_test.go b/auth_flow_behaviour_test.go index 9e694b4..5cb2a86 100644 --- a/auth_flow_behaviour_test.go +++ b/auth_flow_behaviour_test.go @@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() { // Pre-populate session with old data _ = session.SetAuthenticated(true) - session.SetEmail("old@example.com") + session.SetUserIdentifier("old@example.com") session.SetAccessToken("old-access-token-with-many-characters") session.SetRefreshToken("old-refresh-token-with-many-characters") session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature") @@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() { // Verify old data is cleared s.False(session.GetAuthenticated()) - s.Empty(session.GetEmail()) + s.Empty(session.GetUserIdentifier()) // Verify new data is set s.Equal(csrfToken, session.GetCSRF()) @@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() { session, err := sessionManager.GetSession(req) s.Require().NoError(err) _ = session.SetAuthenticated(true) - session.SetEmail("test@example.com") + session.SetUserIdentifier("test@example.com") session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature") session.mainSession.Values["redirect_count"] = 3 @@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() { // Session should be cleared s.False(session.GetAuthenticated()) - s.Empty(session.GetEmail()) + s.Empty(session.GetUserIdentifier()) s.Empty(session.GetIDToken()) // Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication diff --git a/csrf_session_test.go b/csrf_session_test.go index 4125187..b2a1546 100644 --- a/csrf_session_test.go +++ b/csrf_session_test.go @@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) { session.SetCSRF(csrfToken) session.SetNonce("test-nonce") session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAccessToken("old-access-token") session.SetRefreshToken("old-refresh-token") session.SetIDToken("old-id-token") @@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) { // Now perform selective clearing (as done in the fix) session2.SetAuthenticated(false) - session2.SetEmail("") + session2.SetUserIdentifier("") session2.SetAccessToken("") session2.SetRefreshToken("") session2.SetIDToken("") @@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) { // Set initial session data session.SetAuthenticated(true) - session.SetEmail("old@example.com") + session.SetUserIdentifier("old@example.com") session.SetAccessToken("old-token") session.SetCSRF("existing-csrf") @@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) { // OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF // NEW BEHAVIOR: Selective clearing session2.SetAuthenticated(false) - session2.SetEmail("") + session2.SetUserIdentifier("") session2.SetAccessToken("") session2.SetRefreshToken("") session2.SetIDToken("") diff --git a/issue132_regression_test.go b/issue132_regression_test.go new file mode 100644 index 0000000..7c70a3b --- /dev/null +++ b/issue132_regression_test.go @@ -0,0 +1,135 @@ +package traefikoidc + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies +// the fix for issue #132: token refresh path hardcoded the "email" claim and +// ignored the configured userIdentifierClaim. Keycloak users without an email +// claim (using sub or another identifier) were being kicked out on refresh +// even though their initial login worked. +// +// The callback path (auth_flow.go) already honored userIdentifierClaim with +// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync +// after PR #100 (commit a316a98). +func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) { + tests := []struct { + claims map[string]any + name string + userIdentifierClaim string + expectedIdentifier string + expectSuccess bool + }{ + { + name: "sub claim configured, only sub present (Keycloak no-email case)", + userIdentifierClaim: "sub", + claims: map[string]any{ + "sub": "user-uuid-keycloak-12345", + "exp": float64(9999999999), + }, + expectSuccess: true, + expectedIdentifier: "user-uuid-keycloak-12345", + }, + { + name: "preferred_username configured, claim present", + userIdentifierClaim: "preferred_username", + claims: map[string]any{ + "sub": "user-uuid-12345", + "preferred_username": "alice", + "exp": float64(9999999999), + }, + expectSuccess: true, + expectedIdentifier: "alice", + }, + { + name: "configured claim missing, falls back to sub", + userIdentifierClaim: "preferred_username", + claims: map[string]any{ + "sub": "fallback-sub-id", + "exp": float64(9999999999), + }, + expectSuccess: true, + expectedIdentifier: "fallback-sub-id", + }, + { + name: "email default, email present (backward compatibility)", + userIdentifierClaim: "email", + claims: map[string]any{ + "sub": "user-uuid-12345", + "email": "user@example.com", + "exp": float64(9999999999), + }, + expectSuccess: true, + expectedIdentifier: "user@example.com", + }, + { + name: "email default, no email and no sub - refresh fails", + userIdentifierClaim: "email", + claims: map[string]any{ + "exp": float64(9999999999), + }, + expectSuccess: false, + expectedIdentifier: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sessionManager, err := NewSessionManager( + "test-encryption-key-32-bytes-long!!", + false, + "", + "", + 0, + NewLogger("error"), + ) + if err != nil { + t.Fatalf("session manager: %v", err) + } + defer sessionManager.Shutdown() + + capturedClaims := tt.claims + tOidc := &TraefikOidc{ + logger: NewLogger("error"), + userIdentifierClaim: tt.userIdentifierClaim, + sessionManager: sessionManager, + tokenExchanger: &EnhancedMockTokenExchanger{ + RefreshResponse: &TokenResponse{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + IDToken: "new-id-token-jwt", + ExpiresIn: 3600, + }, + }, + tokenVerifier: &EnhancedMockTokenVerifier{Err: nil}, + extractClaimsFunc: func(token string) (map[string]any, error) { + return capturedClaims, nil + }, + } + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + rw := httptest.NewRecorder() + + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("get session: %v", err) + } + defer session.returnToPoolSafely() + + session.SetRefreshToken("initial-refresh-token") + + refreshed := tOidc.refreshToken(rw, req, session) + + if refreshed != tt.expectSuccess { + t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess) + } + + if got := session.GetUserIdentifier(); got != tt.expectedIdentifier { + t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier) + } + }) + } +} diff --git a/main_servehttp_test.go b/main_servehttp_test.go index 3ca87d1..c2aab43 100644 --- a/main_servehttp_test.go +++ b/main_servehttp_test.go @@ -138,7 +138,7 @@ func TestServeHTTP_EventStream(t *testing.T) { if err != nil { t.Fatalf("failed to create test session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") if err := session.SetAuthenticated(true); err != nil { t.Fatalf("failed to mark session authenticated: %v", err) } @@ -221,7 +221,7 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) { if err != nil { t.Fatalf("failed to create test session: %v", err) } - session.SetEmail("ws-user@example.com") + session.SetUserIdentifier("ws-user@example.com") if err := session.SetAuthenticated(true); err != nil { t.Fatalf("failed to mark session authenticated: %v", err) } @@ -408,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "successful authorization with email", setupSession: func() *MockSessionData { session := &MockSessionData{ - email: "user@example.com", + userIdentifier: "user@example.com", idToken: "test-id-token", accessToken: "test-access-token", isDirty: false, @@ -440,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "no email triggers reauth", setupSession: func() *MockSessionData { return &MockSessionData{ - email: "", + userIdentifier: "", idToken: "test-id-token", accessToken: "test-access-token", } @@ -461,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "roles and groups authorization", setupSession: func() *MockSessionData { return &MockSessionData{ - email: "user@example.com", + userIdentifier: "user@example.com", idToken: "test-id-token", accessToken: "test-access-token", } @@ -494,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "unauthorized role/group returns 403", setupSession: func() *MockSessionData { return &MockSessionData{ - email: "user@example.com", + userIdentifier: "user@example.com", idToken: "test-id-token", accessToken: "test-access-token", } @@ -521,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "template headers processing", setupSession: func() *MockSessionData { return &MockSessionData{ - email: "user@example.com", + userIdentifier: "user@example.com", idToken: "test-id-token", accessToken: "test-access-token", isDirty: false, @@ -553,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { name: "OPTIONS request with CORS", setupSession: func() *MockSessionData { return &MockSessionData{ - email: "user@example.com", + userIdentifier: "user@example.com", idToken: "test-id-token", accessToken: "test-access-token", } @@ -604,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) { manager: &SessionManager{logger: NewLogger("debug")}, } // Copy values from mock to concrete session - concreteSession.SetEmail(session.email) + concreteSession.SetUserIdentifier(session.userIdentifier) concreteSession.SetIDToken(session.idToken) concreteSession.SetAccessToken(session.accessToken) concreteSession.SetRefreshToken(session.refreshToken) @@ -654,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) { // MockSessionData is a test implementation of SessionData interface type MockSessionData struct { - email string - idToken string - accessToken string - refreshToken string - csrf string - nonce string - codeVerifier string - redirectCount int - authenticated bool - isDirty bool + userIdentifier string + idToken string + accessToken string + refreshToken string + csrf string + nonce string + codeVerifier string + redirectCount int + authenticated bool + isDirty bool } -func (m *MockSessionData) GetEmail() string { return m.email } +func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier } func (m *MockSessionData) GetIDToken() string { return m.idToken } func (m *MockSessionData) GetAccessToken() string { return m.accessToken } func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken } -func (m *MockSessionData) SetEmail(email string) { m.email = email } +func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier } func (m *MockSessionData) SetIDToken(token string) { m.idToken = token } func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token } func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token } @@ -762,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) { } // Set up session data - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) // Call processAuthorizedRequest directly @@ -837,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback") @@ -923,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) { if err != nil { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) // Now add OIDC session cookies (simulating what the browser would send) @@ -1004,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) { if err != nil { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback") @@ -1051,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) { if err != nil { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) // Add only OIDC cookies @@ -1102,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) { if err != nil { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) // Add only non-OIDC cookies @@ -1165,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) { if err != nil { t.Fatalf("Failed to get session: %v", err) } - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) // Add cookies with the custom prefix (should be stripped) diff --git a/main_test.go b/main_test.go index 8abff2d..34b8ea1 100644 --- a/main_test.go +++ b/main_test.go @@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Generate a fresh valid token for this test case to avoid replay issues freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), @@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) { // even if session.SetAuthenticated(true) was called. // We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt. session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), @@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), @@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Set flag initially - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), @@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) // Set flag initially - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Create an expired token for this test expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(), @@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) { "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), }) session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAccessToken(nearExpiryToken) session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh }, @@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) { "sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16), }) session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAccessToken(validToken) session.SetIDToken(validToken) // Ensure ID token is also set session.SetRefreshToken("should-not-be-used-refresh-token") @@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@disallowed.com") // Use disallowed domain + session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), @@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) { requestPath: "/protected", setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@disallowed.com") // Use disallowed domain + session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain // Generate a fresh valid token for this test case freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{ "iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(), @@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) { "sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16), }) session.SetAccessToken(expiredToken) - session.SetEmail("test@example.com") + session.SetUserIdentifier("test@example.com") }, expectedPath: "/original/path", }, @@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ @@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ @@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") }, expectedStatus: http.StatusForbidden, }, @@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{ @@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) { }, setupSession: func(session *SessionData) { session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") }, expectedStatus: http.StatusOK, expectedHeaders: map[string]string{}, diff --git a/middleware.go b/middleware.go index fa636f8..083b637 100644 --- a/middleware.go +++ b/middleware.go @@ -92,17 +92,17 @@ func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) b return false } - email := session.GetEmail() - if email == "" { + userIdentifier := session.GetUserIdentifier() + if userIdentifier == "" { t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason) return false } - req.Header.Set("X-Forwarded-User", email) + req.Header.Set("X-Forwarded-User", userIdentifier) if !t.minimalHeaders { - req.Header.Set("X-Auth-Request-User", email) + req.Header.Set("X-Auth-Request-User", userIdentifier) } - t.logger.Debugf("%s bypass: forwarded user %s from session", reason, email) + t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier) return true } @@ -289,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim) + userIdentifier := session.GetUserIdentifier() // User authorization check if authenticated && userIdentifier != "" { if !t.isAllowedUser(userIdentifier) { @@ -361,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { refreshed := t.refreshToken(rw, req, session) if refreshed { - userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier + userIdentifier = session.GetUserIdentifier() if userIdentifier != "" && !t.isAllowedUser(userIdentifier) { t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier) errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath) @@ -411,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // - session: The user's session data containing tokens and claims. // - redirectURL: The callback URL for re-authentication if needed. func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { - email := session.GetEmail() - if email == "" { - t.logger.Info("No email found in session during final processing, initiating re-auth") + userIdentifier := session.GetUserIdentifier() + if userIdentifier == "" { + t.logger.Info("No user identifier found in session during final processing, initiating re-auth") // Reset redirect count to prevent loops when session is invalid session.ResetRedirectCount() t.defaultInitiateAuthentication(rw, req, session, redirectURL) @@ -426,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http if idToken != "" { sid, sub, createdAt := t.extractSessionInfo(idToken) if t.isSessionInvalidated(sid, sub, createdAt) { - t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email) + t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier) // Clear the session and redirect to login if err := session.Clear(req, rw); err != nil { t.logger.Errorf("Error clearing invalidated session: %v", err) @@ -502,19 +502,19 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http } } if !allowed { - t.logger.Infof("User with email %s does not have any allowed roles or groups", email) + t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier) errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath) t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } } - req.Header.Set("X-Forwarded-User", email) + req.Header.Set("X-Forwarded-User", userIdentifier) // When minimalHeaders is enabled, skip extra headers to prevent 431 errors if !t.minimalHeaders { req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) - req.Header.Set("X-Auth-Request-User", email) + req.Header.Set("X-Auth-Request-User", userIdentifier) if idToken != "" { req.Header.Set("X-Auth-Request-Token", idToken) } @@ -587,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http } } - t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email) + t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier) t.next.ServeHTTP(rw, req) } diff --git a/middleware_edge_cases_test.go b/middleware_edge_cases_test.go index 8deabf8..5aab76c 100644 --- a/middleware_edge_cases_test.go +++ b/middleware_edge_cases_test.go @@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) { // Create authenticated session req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAuthenticated(true) session.SetIDToken("dummy-token") session.Save(req, httptest.NewRecorder()) @@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) { // Create session with forbidden domain req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("user@forbidden.com") + session.SetUserIdentifier("user@forbidden.com") session.SetAuthenticated(true) // Save and inject cookies @@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) { // Create session with opaque token req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots) session.SetAuthenticated(true) @@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) { req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("") // No email + session.SetUserIdentifier("") // No email session.SetIDToken("dummy-token") rw := httptest.NewRecorder() @@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) { req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetIDToken("") // No ID token session.SetAccessToken("") // No access token @@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) { req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") session.SetIDToken("dummy-token") rw := httptest.NewRecorder() @@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) { req := httptest.NewRequest("GET", "/api/test", nil) session, _ := sessionManager.GetSession(req) testEmail := "user@example.com" - session.SetEmail(testEmail) + session.SetUserIdentifier(testEmail) session.SetIDToken("dummy-id-token") rw := httptest.NewRecorder() diff --git a/regression/regression_test.go b/regression/regression_test.go index 508c451..db2d924 100644 --- a/regression/regression_test.go +++ b/regression/regression_test.go @@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) { // Simulate successful Azure authentication session.SetAuthenticated(true) - session.SetEmail("user@example.com") + session.SetUserIdentifier("user@example.com") // Azure may use opaque access tokens session.SetAccessToken("opaque-azure-access-token") session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore @@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) { require.NoError(t, err) assert.True(t, session2.GetAuthenticated(), "User should remain authenticated") - assert.Equal(t, "user@example.com", session2.GetEmail()) + assert.Equal(t, "user@example.com", session2.GetUserIdentifier()) assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist") assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist") assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist") diff --git a/security_edge_cases_test.go b/security_edge_cases_test.go index d967c3e..ccc78e9 100644 --- a/security_edge_cases_test.go +++ b/security_edge_cases_test.go @@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) { // Set up the attacker's session with malicious data attackerSession.SetAuthenticated(true) - attackerSession.SetEmail("attacker@evil.com") + attackerSession.SetUserIdentifier("attacker@evil.com") attackerSession.SetIDToken(ValidIDToken) attackerSession.SetAccessToken(ValidAccessToken) @@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) { } // Get the email from the session - email := session.GetEmail() + email := session.GetUserIdentifier() w.Header().Set("X-User-Email", email) w.WriteHeader(http.StatusOK) }) diff --git a/session.go b/session.go index 8954783..d282d1f 100644 --- a/session.go +++ b/session.go @@ -100,7 +100,7 @@ type combinedSessionPayload struct { A string `json:"a,omitempty"` R string `json:"r,omitempty"` I string `json:"i,omitempty"` - E string `json:"e,omitempty"` + Ui string `json:"ui,omitempty"` Cs string `json:"cs,omitempty"` N string `json:"n,omitempty"` Cv string `json:"cv,omitempty"` @@ -113,11 +113,11 @@ type combinedSessionPayload struct { // knownSessionKeys are the standard keys that are handled explicitly in the combined payload. // All other mainSession.Values keys are stored in the X (extra) field. var knownSessionKeys = map[string]bool{ - "access_token": true, - "refresh_token": true, - "id_token": true, - "email": true, - "authenticated": true, + "access_token": true, + "refresh_token": true, + "id_token": true, + "user_identifier": true, + "authenticated": true, "csrf": true, "nonce": true, "code_verifier": true, @@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData * sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName()) // Populate legacy session values from combined payload - sessionData.mainSession.Values["email"] = payload.E + sessionData.mainSession.Values["user_identifier"] = payload.Ui sessionData.mainSession.Values["authenticated"] = payload.Au sessionData.mainSession.Values["csrf"] = payload.Cs sessionData.mainSession.Values["nonce"] = payload.N @@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti A: sd.getAccessTokenUnsafe(), R: sd.getRefreshTokenUnsafe(), I: sd.getIDTokenUnsafe(), - E: sd.getEmailUnsafe(), + Ui: sd.getUserIdentifierUnsafe(), Au: sd.getAuthenticatedUnsafe(), Cs: sd.getCSRFUnsafe(), N: sd.getNonceUnsafe(), @@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) { } } -// GetEmail retrieves the authenticated user's email address. -// The email is extracted from ID token claims and used for -// authorization decisions and header injection. +// GetUserIdentifier retrieves the authenticated user's identifier as extracted +// from the configured userIdentifierClaim of the ID token (email, sub, oid, +// upn, preferred_username, etc.). The value is used for authorization +// decisions and header injection. // Returns: -// - The user's email address string, or an empty string if not set. -func (sd *SessionData) GetEmail() string { +// - The user identifier string, or an empty string if not set. +func (sd *SessionData) GetUserIdentifier() string { sd.sessionMutex.RLock() defer sd.sessionMutex.RUnlock() - email, _ := sd.mainSession.Values["email"].(string) - return email + userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string) + return userIdentifier } -// SetEmail stores the authenticated user's email address. -// The email is typically extracted from the 'email' claim in the ID token. +// SetUserIdentifier stores the authenticated user's identifier value. // Parameters: -// - email: The user's email address to store. -func (sd *SessionData) SetEmail(email string) { +// - userIdentifier: The user identifier to store (email, sub, or other claim value). +func (sd *SessionData) SetUserIdentifier(userIdentifier string) { sd.sessionMutex.Lock() defer sd.sessionMutex.Unlock() - currentVal, _ := sd.mainSession.Values["email"].(string) - if currentVal != email { - sd.mainSession.Values["email"] = email + currentVal, _ := sd.mainSession.Values["user_identifier"].(string) + if currentVal != userIdentifier { + sd.mainSession.Values["user_identifier"] = userIdentifier sd.dirty = true } } @@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string { return result.Token } -// getEmailUnsafe retrieves the email without acquiring locks. -func (sd *SessionData) getEmailUnsafe() string { - email, _ := sd.mainSession.Values["email"].(string) - return email +// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks. +func (sd *SessionData) getUserIdentifierUnsafe() string { + userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string) + return userIdentifier } // getCSRFUnsafe retrieves the CSRF token without acquiring locks. diff --git a/session_behaviour_test.go b/session_behaviour_test.go index cc22b59..cb146d1 100644 --- a/session_behaviour_test.go +++ b/session_behaviour_test.go @@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() { s.False(session.IsDirty()) } -// TestSessionData_SetEmail tests email setter with dirty tracking -func (s *SessionBehaviourSuite) TestSessionData_SetEmail() { +// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking +func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() { req := httptest.NewRequest(http.MethodGet, "/test", nil) session, err := s.sessionManager.GetSession(req) s.Require().NoError(err) defer session.returnToPoolSafely() - // Set email - session.SetEmail("test@example.com") - s.Equal("test@example.com", session.GetEmail()) + session.SetUserIdentifier("test@example.com") + s.Equal("test@example.com", session.GetUserIdentifier()) s.True(session.IsDirty()) } @@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() { // Set some data err = session.SetAuthenticated(true) s.Require().NoError(err) - session.SetEmail("test@example.com") + session.SetUserIdentifier("test@example.com") session.SetCSRF("csrf-token") // Clear session @@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() { defer session.returnToPoolSafely() // Modify session - session.SetEmail("test@example.com") + session.SetUserIdentifier("test@example.com") s.True(session.IsDirty()) // Save session diff --git a/session_test.go b/session_test.go index 691fe7e..e34e1e4 100644 --- a/session_test.go +++ b/session_test.go @@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) { // Set up initial session state (what user has when first logging in) session1.SetAuthenticated(true) - session1.SetEmail(originalUserData["email"].(string)) + session1.SetUserIdentifier(originalUserData["email"].(string)) session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars") session1.SetIDToken("initial-valid-id-token-longer-than-20-chars") session1.SetRefreshToken("valid-refresh-token-should-last-30-days") @@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) { // Simulate what happens when middleware detects expired tokens // It should preserve session state while attempting token refresh originalAuth := session2.GetAuthenticated() - originalEmail := session2.GetEmail() + originalEmail := session2.GetUserIdentifier() // Reconstruct user data from individual stored keys originalUserDataStored := make(map[string]interface{}) @@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) { // Verify all session data is still intact after token refresh postRefreshAuth := session2.GetAuthenticated() - postRefreshEmail := session2.GetEmail() + postRefreshEmail := session2.GetUserIdentifier() userDataPresent := true for k := range originalUserData { if session2.mainSession.Values["user_data_"+k] == nil { @@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) { // Set up session with specific creation time session.SetAuthenticated(true) - session.SetEmail("test@example.com") + session.SetUserIdentifier("test@example.com") session.mainSession.Values["created_at"] = sessionCreatedAt.Unix() // Create tokens with specific expiry @@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) { // Set up session with data that should be preserved or removed session.SetAuthenticated(true) - session.SetEmail("cleanup@example.com") + session.SetUserIdentifier("cleanup@example.com") session.mainSession.Values["user_data"] = "Test User|user-123" session.mainSession.Values["preferences"] = "theme:dark,lang:en" @@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) { if scenario.shouldCleanup { if sessionTooOld { session.SetAuthenticated(false) - session.SetEmail("") + session.SetUserIdentifier("") session.SetAccessToken("") session.SetRefreshToken("") for key := range session.mainSession.Values { diff --git a/test_framework_test.go b/test_framework_test.go index 27f3e21..a18b646 100644 --- a/test_framework_test.go +++ b/test_framework_test.go @@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http. } session.SetAuthenticated(true) - session.SetEmail(tf.fixtures.UserEmail) + session.SetUserIdentifier(tf.fixtures.UserEmail) session.SetAccessToken(tf.fixtures.AccessToken) session.SetRefreshToken(tf.fixtures.RefreshToken) session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims)) diff --git a/token_manager.go b/token_manager.go index a20fcaf..9463410 100644 --- a/token_manager.go +++ b/token_manager.go @@ -434,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se session.SetRefreshToken("") session.SetAccessToken("") session.SetIDToken("") - session.SetEmail("") + session.SetUserIdentifier("") // Clear CSRF tokens as well to prevent any replay attacks session.SetCSRF("") session.SetNonce("") @@ -476,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err) return false } - email, _ := claims["email"].(string) - if email == "" { - t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token") - return false + userIdentifier, _ := claims[t.userIdentifierClaim].(string) + if userIdentifier == "" { + if t.userIdentifierClaim != "sub" { + userIdentifier, _ = claims["sub"].(string) + } + if userIdentifier == "" { + t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim) + return false + } + t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim) } - session.SetEmail(email) + session.SetUserIdentifier(userIdentifier) // Get token expiry information for logging var expiryTime time.Time @@ -507,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se session.SetAccessToken("") session.SetIDToken("") session.SetRefreshToken("") - session.SetEmail("") + session.SetUserIdentifier("") return false }