Files
traefikoidc/main_servehttp_test.go
T
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00

1206 lines
39 KiB
Go

package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestServeHTTP_ExcludedURLs tests the excluded URLs functionality
func TestServeHTTP_ExcludedURLs(t *testing.T) {
tests := []struct {
excludedURLs map[string]struct{}
name string
path string
shouldBypass bool
}{
{
name: "favicon excluded by default",
path: "/favicon.ico",
excludedURLs: defaultExcludedURLs,
shouldBypass: true,
},
{
name: "health endpoint excluded",
path: "/health",
excludedURLs: map[string]struct{}{"/health": {}},
shouldBypass: true,
},
{
name: "API endpoint excluded",
path: "/api/v1/status",
excludedURLs: map[string]struct{}{"/api": {}},
shouldBypass: true,
},
{
name: "normal path not excluded",
path: "/dashboard",
excludedURLs: map[string]struct{}{},
shouldBypass: false,
},
{
name: "metrics endpoint excluded",
path: "/metrics",
excludedURLs: map[string]struct{}{"/metrics": {}},
shouldBypass: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
oidc := &TraefikOidc{
excludedURLs: tt.excludedURLs,
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com", // Required for initialization check
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", tt.path, nil)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if tt.shouldBypass && !nextCalled {
t.Error("expected request to bypass OIDC, but next handler was not called")
}
})
}
}
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
// handshake must skip the OIDC redirect dance (clients can't follow it
// mid-stream) but it must STILL require an authenticated session, otherwise
// any caller could reach the backend by setting Accept: text/event-stream.
func TestServeHTTP_EventStream(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
}
})
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
// Build an authenticated session and inject its cookies onto req.
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated SSE request to be forwarded to backend")
}
if forwardedUser != "user@example.com" {
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
}
})
}
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
// handshake bypasses the OIDC redirect (clients can't follow it) but the
// session must already be authenticated, otherwise the backend is exposed
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
}
})
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
}))
req := httptest.NewRequest("GET", "/ws", nil)
// Mixed-case + multi-token Connection header to exercise parsing.
req.Header.Set("Connection", "keep-alive, Upgrade")
req.Header.Set("Upgrade", "WebSocket")
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("ws-user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
}
if forwardedUser != "ws-user@example.com" {
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
}
})
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
// Sanity: requests without Upgrade headers must NOT hit the WS
// bypass branch (otherwise the new code path could short-circuit
// normal authentication).
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("backend must not be called for unauthenticated plain HTTP")
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "keep-alive")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code == http.StatusOK {
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
}
})
}
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
func TestServeHTTP_InitializationTimeout(t *testing.T) {
t.Run("timeout waiting for initialization", func(t *testing.T) {
// Use a shorter timeout for testing
oldTimeout := 30 * time.Second
shortTimeout := 100 * time.Millisecond
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close this to simulate timeout
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
}
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
// Start request in goroutine with short timeout
done := make(chan bool)
go func() {
// Override timeout in test
start := time.Now()
go func() {
time.Sleep(shortTimeout)
if time.Since(start) >= shortTimeout {
// Simulate timeout by canceling
close(done)
}
}()
oidc.ServeHTTP(rw, req)
}()
select {
case <-done:
// Timeout occurred as expected
case <-time.After(oldTimeout):
t.Error("request did not timeout as expected")
}
})
t.Run("successful initialization", func(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
}
// Close init channel to signal completion
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Should not return an initialization error
if rw.Code == http.StatusServiceUnavailable {
t.Error("expected successful request after initialization")
}
})
}
// TestServeHTTP_CallbackAndLogout tests callback and logout path handling
func TestServeHTTP_CallbackAndLogout(t *testing.T) {
t.Run("callback path triggers callback handler", func(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
tokenURL: "https://provider.example.com/token",
clientID: "test-client",
audience: "test-client",
clientSecret: "test-secret",
tokenHTTPClient: http.DefaultClient,
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
rw := httptest.NewRecorder()
// This will trigger handleCallback
oidc.ServeHTTP(rw, req)
// Check that we got a response (even if it's an error due to invalid code)
if rw.Code == 0 {
t.Error("expected response from callback handler")
}
})
t.Run("logout path triggers logout handler", func(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
endSessionURL: "https://provider.example.com/logout",
postLogoutRedirectURI: "https://example.com",
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/logout", nil)
rw := httptest.NewRecorder()
// This will trigger handleLogout
oidc.ServeHTTP(rw, req)
// Check that we got a redirect response
if rw.Code != http.StatusFound && rw.Code != http.StatusSeeOther {
t.Errorf("expected redirect response, got %d", rw.Code)
}
})
}
// TestProcessAuthorizedRequest_Skipped tests the processAuthorizedRequest function
// NOTE: This test is currently skipped due to complex SessionData requirements.
// The function is tested indirectly through ServeHTTP tests above.
/*
func TestProcessAuthorizedRequest(t *testing.T) {
tests := []struct {
name string
setupSession func() *MockSessionData
setupOidc func() *TraefikOidc
expectedHeaders map[string]string
expectNextCalled bool
expectReauth bool
expectedStatus int
}{
{
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
}
return session
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-Forwarded-User": "user@example.com",
"X-Auth-Request-User": "user@example.com",
"X-Auth-Request-Token": "test-id-token",
},
expectNextCalled: true,
expectReauth: false,
},
{
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
authURL: "https://provider.example.com/auth",
clientID: "test-client",
audience: "test-client",
redirURLPath: "/callback",
}
},
expectNextCalled: false,
expectReauth: true,
},
{
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"users": {},
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"groups": []interface{}{"users", "developers"},
"roles": []interface{}{"viewer"},
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-User-Groups": "users,developers",
"X-User-Roles": "viewer",
},
expectNextCalled: true,
},
{
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
logoutURLPath: "/logout",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"groups": []interface{}{"users"},
"roles": []interface{}{"viewer"},
}, nil
},
}
},
expectNextCalled: false,
expectedStatus: http.StatusForbidden,
},
{
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
}
},
setupOidc: func() *TraefikOidc {
tmpl, _ := template.New("test").Parse("{{.Claims.email}}")
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
headerTemplates: map[string]*template.Template{
"X-Custom-Email": tmpl,
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-Custom-Email": "user@example.com",
},
expectNextCalled: true,
},
{
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{}, nil
},
}
},
expectNextCalled: false, // OPTIONS returns immediately
expectedStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
oidc := tt.setupOidc()
req := httptest.NewRequest("GET", "/protected", nil)
if strings.Contains(tt.name, "OPTIONS") {
req = httptest.NewRequest("OPTIONS", "/protected", nil)
req.Header.Set("Origin", "https://example.com")
}
rw := httptest.NewRecorder()
nextCalled := false
if oidc.next == nil {
oidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
} else {
originalNext := oidc.next
oidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
originalNext.ServeHTTP(w, r)
})
}
// Call the function - we need to use the concrete SessionData type
// For testing, we'll create a minimal SessionData that implements what we need
concreteSession := &SessionData{
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
concreteSession.SetAuthenticated(session.authenticated)
if session.isDirty {
concreteSession.MarkDirty()
}
oidc.processAuthorizedRequest(rw, req, concreteSession, "https://example.com/callback")
// Verify expectations
if tt.expectNextCalled && !nextCalled {
t.Error("expected next handler to be called")
}
if !tt.expectNextCalled && nextCalled {
t.Error("expected next handler NOT to be called")
}
// Check headers
for header, expectedValue := range tt.expectedHeaders {
if got := req.Header.Get(header); got != expectedValue {
t.Errorf("expected header %s = %q, got %q", header, expectedValue, got)
}
}
// Check status code if specified
if tt.expectedStatus > 0 && rw.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, rw.Code)
}
// Check security headers are set
securityHeaders := []string{
"X-Frame-Options",
"X-Content-Type-Options",
"X-XSS-Protection",
"Referrer-Policy",
}
for _, header := range securityHeaders {
if rw.Header().Get(header) == "" {
t.Errorf("expected security header %s to be set", header)
}
}
})
}
}
*/
// MockSessionData is a test implementation of SessionData interface
type MockSessionData struct {
userIdentifier string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string { return m.idToken }
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
func (m *MockSessionData) SetAuthenticated(auth bool) { m.authenticated = auth }
func (m *MockSessionData) IsAuthenticated() bool { return m.authenticated }
func (m *MockSessionData) IsDirty() bool { return m.isDirty }
func (m *MockSessionData) MarkDirty() { m.isDirty = true }
func (m *MockSessionData) ResetRedirectCount() { m.redirectCount = 0 }
func (m *MockSessionData) IncrementRedirectCount() int { m.redirectCount++; return m.redirectCount }
func (m *MockSessionData) GetCSRF() string { return m.csrf }
func (m *MockSessionData) SetCSRF(csrf string) { m.csrf = csrf }
func (m *MockSessionData) GetNonce() string { return m.nonce }
func (m *MockSessionData) SetNonce(nonce string) { m.nonce = nonce }
func (m *MockSessionData) GetCodeVerifier() string { return m.codeVerifier }
func (m *MockSessionData) SetCodeVerifier(verifier string) { m.codeVerifier = verifier }
func (m *MockSessionData) Save(r *http.Request, w http.ResponseWriter) error { return nil }
func (m *MockSessionData) Clear(r *http.Request, w http.ResponseWriter) error { return nil }
// Helper function to create a test session manager
func createTestSessionManager(t *testing.T) *SessionManager {
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
return sm
}
// TestMinimalHeaders tests the minimalHeaders configuration option
// This addresses GitHub issue #64 - Request Header Fields Too Large
func TestMinimalHeaders(t *testing.T) {
tests := []struct {
name string
minimalHeaders bool
expectForwardedUser bool
expectAuthRequestUser bool
expectAuthRequestRedirect bool
}{
{
name: "minimalHeaders=false (default) forwards all headers",
minimalHeaders: false,
expectForwardedUser: true,
expectAuthRequestUser: true,
expectAuthRequestRedirect: true,
},
{
name: "minimalHeaders=true only forwards X-Forwarded-User",
minimalHeaders: true,
expectForwardedUser: true,
expectAuthRequestUser: false,
expectAuthRequestRedirect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Track which headers were set
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
minimalHeaders: tt.minimalHeaders,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
// Create request and get session properly through session manager
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set up session data
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Verify X-Forwarded-User is always set
if tt.expectForwardedUser {
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
}
}
// Verify X-Auth-Request-User
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
if tt.expectAuthRequestUser && !hasAuthRequestUser {
t.Error("expected X-Auth-Request-User to be set")
}
if !tt.expectAuthRequestUser && hasAuthRequestUser {
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
}
// Verify X-Auth-Request-Redirect
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
t.Error("expected X-Auth-Request-Redirect to be set")
}
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
}
// Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty.
// Token storage has validation that may reject test tokens, so we verify the flag
// logic through the other headers. The important behavior is that when
// minimalHeaders=true, the token header would NOT be set even if a token existed.
})
}
}
// TestMinimalHeaders_TokenHeaderNotSet verifies that the X-Auth-Request-Token header
// is NOT set when minimalHeaders is enabled, even if a token exists.
func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
minimalHeaders: true, // Enable minimal headers
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Verify X-Forwarded-User is set (always should be)
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
}
// The key verification: X-Auth-Request-Token should NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-Token") != "" {
t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true")
}
// X-Auth-Request-User should also NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-User") != "" {
t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true")
}
// X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true
if capturedHeaders.Get("X-Auth-Request-Redirect") != "" {
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
}
}
// TestStripAuthCookies tests the stripAuthCookies configuration option.
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
func TestStripAuthCookies(t *testing.T) {
tests := []struct {
name string
stripAuthCookies bool
expectOIDCCookies bool
expectAppCookies bool
}{
{
name: "stripAuthCookies=false (default) forwards all cookies",
stripAuthCookies: false,
expectOIDCCookies: true,
expectAppCookies: true,
},
{
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
stripAuthCookies: true,
expectOIDCCookies: false,
expectAppCookies: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: tt.stripAuthCookies,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
// Get a valid session first (before adding fake cookies)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
// Add non-OIDC application cookies (these must always pass through)
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Check for OIDC cookies in captured cookies
hasOIDCCookie := false
hasAppSession := false
hasTheme := false
for _, c := range capturedCookies {
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
hasOIDCCookie = true
}
if c.Name == "my_app_session" {
hasAppSession = true
}
if c.Name == "theme" {
hasTheme = true
}
}
if tt.expectOIDCCookies && !hasOIDCCookie {
t.Error("expected OIDC cookies to be forwarded to backend")
}
if !tt.expectOIDCCookies && hasOIDCCookie {
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
}
if tt.expectAppCookies && !hasAppSession {
t.Error("expected my_app_session cookie to be forwarded to backend")
}
if tt.expectAppCookies && !hasTheme {
t.Error("expected theme cookie to be forwarded to backend")
}
})
}
}
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
func TestStripAuthCookies_NoCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected no cookies, got %d", len(capturedCookies))
}
}
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
// the Cookie header is empty after stripping.
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
var capturedCookieHeader string
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookieHeader = r.Header.Get("Cookie")
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
}
if capturedCookieHeader != "" {
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
}
}
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
// untouched when stripping is enabled.
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 3 {
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
}
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
for _, expected := range []string{"my_app_session", "lang", "theme"} {
if !cookieNames[expected] {
t.Errorf("expected cookie %q to be forwarded", expected)
}
}
}
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
// Create session manager with custom prefix
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
customPrefix := sm.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
// Add default-prefix cookie (should NOT be stripped — different prefix)
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
// Add app cookie (should NOT be stripped)
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
// Custom prefix cookies should be stripped
if cookieNames[customPrefix+"m"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
}
if cookieNames[customPrefix+"s_0"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
}
// Default prefix cookie should pass through (different prefix)
if !cookieNames["_oidc_raczylo_m"] {
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
}
// App cookie should pass through
if !cookieNames["my_app"] {
t.Error("expected my_app cookie to pass through")
}
}