Files
traefikoidc/main_servehttp_test.go
lukaszraczylo 72e2b682bb fix: eliminate per-request global mutexes in Yaegi hot paths
The v1.0.14 fix replaced one contended sync.RWMutex (RefreshCoordinator.
refreshMutex) with sync.Map. Production showed the same death-spiral
signature recurring ~2 hours later — same shape, different mutex:
65 goroutines stuck on a sync.(*RWMutex).Lock at one address, pod
pinned at 1000m CPU, identical Yaegi runCfg/reflect.Value.Call stack
pattern. The mutex was RefreshCoordinator.attemptsMutex.

Generalising: under Yaegi (interpreted Go for traefik plugins), any
per-request global mutex acquisition is a latent serialization point.
Because Yaegi routes calls through reflect.Value.Call, a critical
section that takes microseconds in compiled Go stretches to several
milliseconds while the lock is held, and on a GOMAXPROCS=1 pod the
queue of waiters grows without bound.

This commit removes every per-request global mutex on the hot path:

1. RefreshCoordinator.attemptsMutex (sync.RWMutex)
   sessionRefreshAttempts: map -> sync.Map.
   refreshAttemptTracker: all fields atomic (int32, int64 UnixNano,
   cooldownEndNano == 0 as the not-in-cooldown sentinel, replacing
   the inCooldown bool).
   isInCooldown / recordRefreshAttempt / recordRefreshSuccess /
   recordRefreshFailure all become lock-free. Cooldown entry uses
   CompareAndSwapInt64 so only one goroutine logs the transition.

2. RefreshCircuitBreaker.mutex (sync.RWMutex)
   lastFailureTime / lastSuccessTime -> atomic.Int64 UnixNano.
   state and failures already atomic.
   AllowRequest / RecordSuccess / RecordFailure now pure atomic ops.

3. TraefikOidc.firstRequestMutex (sync.Mutex)
   firstRequestReceived bool -> firstRequestStarted int32.
   metadataRefreshStarted bool -> metadataRefreshStartedAtomic int32.
   ServeHTTP bootstrap path uses CompareAndSwapInt32 — fires once,
   zero steady-state cost. Previously the mutex was acquired on
   every non-health request forever.

4. TraefikOidc.metadataRetryMutex (sync.Mutex)
   lastMetadataRetryTime time.Time -> lastMetadataRetryNano int64.
   The 30-second retry throttle is now a CAS on lastMetadataRetryNano.

cleanupStaleEntries iterates via sync.Map.Range; eviction is a
CompareAndDelete by pointer identity so a tracker freshly re-used by
a concurrent caller is not lost.

Empirical evidence (analysis of the v1.0.14 spike by three specialist
agents; profiles in /tmp/traefik-spike-1779511683/):
  * mutex profile: 97% delay in sync.(*Mutex).Unlock via
    HTTPHandlerSwitcher -> accesslog -> metrics -> backoff.RetryNotify
  * 65 stuck goroutines at one RWMutex address (0x40022eb648),
    identical Yaegi CFG pointer, all on rc.attemptsMutex via
    recordRefreshAttempt + isInCooldown
  * traffic driver: long-lived in-cluster Go-http-client doing
    ~5.4 req/s POST embeddings via OIDC cookie session → same
    sessionID → contention all funnels to one tracker entry

Yaegi support for sync/atomic confirmed at
github.com/traefik/yaegi@v0.16.1/stdlib/go1_22_sync_atomic.go:
AddInt32/Int64, LoadInt32/Int64, StoreInt32/Int64,
CompareAndSwapInt32/Int64 all exposed via reflect.ValueOf. Yaegi
dispatches each call through reflect.Value.Call to the COMPILED
atomic.* function, which executes a single hardware CAS/LOCK-XADD
instruction. Each atomic op still pays Yaegi dispatch cost but
cannot block — no queueing, no death spiral.

Trade-off acknowledged: v1.0.15 issues ~6-8 atomic/sync.Map ops per
leader-path request vs the 4 mutex ops of v1.0.14. Under low
contention this is a modest CPU bump. Under high contention it's
an unbounded → bounded transformation. Net win.

All tests pass with -race; golangci-lint clean.
2026-05-23 10:47:21 +01:00

1206 lines
39 KiB
Go

package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestServeHTTP_ExcludedURLs verifies the excludedURLs bypass. A request whose
// path matches an excluded entry must go straight to the next handler (the
// "/api" row shows that a prefix match is sufficient). An unauthenticated
// request to a non-excluded path must NOT reach the next handler.
func TestServeHTTP_ExcludedURLs(t *testing.T) {
	tests := []struct {
		excludedURLs map[string]struct{}
		name         string
		path         string
		shouldBypass bool
	}{
		{
			name:         "favicon excluded by default",
			path:         "/favicon.ico",
			excludedURLs: defaultExcludedURLs,
			shouldBypass: true,
		},
		{
			name:         "health endpoint excluded",
			path:         "/health",
			excludedURLs: map[string]struct{}{"/health": {}},
			shouldBypass: true,
		},
		{
			name:         "API endpoint excluded",
			path:         "/api/v1/status",
			excludedURLs: map[string]struct{}{"/api": {}},
			shouldBypass: true,
		},
		{
			name:         "normal path not excluded",
			path:         "/dashboard",
			excludedURLs: map[string]struct{}{},
			shouldBypass: false,
		},
		{
			name:         "metrics endpoint excluded",
			path:         "/metrics",
			excludedURLs: map[string]struct{}{"/metrics": {}},
			shouldBypass: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nextCalled := false
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				w.WriteHeader(http.StatusOK)
			})
			oidc := &TraefikOidc{
				excludedURLs:                 tt.excludedURLs,
				next:                         next,
				logger:                       NewLogger("debug"),
				initComplete:                 make(chan struct{}),
				sessionManager:               createTestSessionManager(t),
				firstRequestStarted:          1,
				metadataRefreshStartedAtomic: 1,
				issuerURL:                    "https://provider.example.com", // Required for initialization check
			}
			close(oidc.initComplete)
			req := httptest.NewRequest("GET", tt.path, nil)
			rw := httptest.NewRecorder()
			oidc.ServeHTTP(rw, req)
			// Check both directions. Without the second case, the
			// shouldBypass=false row would assert nothing at all.
			switch {
			case tt.shouldBypass && !nextCalled:
				t.Error("expected request to bypass OIDC, but next handler was not called")
			case !tt.shouldBypass && nextCalled:
				t.Errorf("expected unauthenticated request to non-excluded path %q to be blocked, but next handler was called", tt.path)
			}
		})
	}
}
// TestServeHTTP_EventStream covers the Server-Sent Events path (a request
// with Accept: text/event-stream). Such a handshake skips the interactive
// OIDC redirect, because a streaming client cannot follow one mid-stream.
// It must still carry an authenticated session; otherwise the Accept header
// alone would be enough to reach the backend.
func TestServeHTTP_EventStream(t *testing.T) {
	sm := createTestSessionManager(t)

	// build returns a fully initialised middleware that forwards to backend.
	build := func(backend http.Handler) *TraefikOidc {
		mw := &TraefikOidc{
			next:                         backend,
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}),
			sessionManager:               sm,
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			issuerURL:                    "https://provider.example.com",
		}
		close(mw.initComplete)
		return mw
	}

	// sseRequest returns a GET /events request that asks for an event stream.
	sseRequest := func() *http.Request {
		r := httptest.NewRequest("GET", "/events", nil)
		r.Header.Set("Accept", "text/event-stream")
		return r
	}

	t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
		reached := false
		mw := build(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			reached = true
			w.WriteHeader(http.StatusOK)
		}))

		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, sseRequest())

		if got := rec.Code; got != http.StatusUnauthorized {
			t.Errorf("expected 401 for unauthenticated SSE request, got %d", got)
		}
		if reached {
			t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
		}
	})

	t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
		var (
			reached bool
			user    string
		)
		mw := build(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reached = true
			user = r.Header.Get("X-Forwarded-User")
			w.WriteHeader(http.StatusOK)
		}))
		req := sseRequest()

		// Create an authenticated session, then copy the cookies that
		// persisting it produces back onto req.
		sess, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("failed to create test session: %v", err)
		}
		sess.SetUserIdentifier("user@example.com")
		if err = sess.SetAuthenticated(true); err != nil {
			t.Fatalf("failed to mark session authenticated: %v", err)
		}
		cookieSink := httptest.NewRecorder()
		if err = sess.Save(req, cookieSink); err != nil {
			t.Fatalf("failed to save session: %v", err)
		}
		for _, c := range cookieSink.Result().Cookies() {
			req.AddCookie(c)
		}

		mw.ServeHTTP(httptest.NewRecorder(), req)

		if !reached {
			t.Fatal("expected authenticated SSE request to be forwarded to backend")
		}
		if user != "user@example.com" {
			t.Errorf("expected X-Forwarded-User=user@example.com, got %q", user)
		}
	})
}
// TestServeHTTP_WebSocketUpgrade is the WebSocket counterpart of the SSE test.
// An upgrade handshake (Connection: Upgrade plus Upgrade: websocket) skips
// the OIDC redirect, which such clients cannot follow, but only when the
// session is already authenticated. Without that check, those two headers
// alone would expose the backend.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
	sm := createTestSessionManager(t)

	// build returns a fully initialised middleware that forwards to backend.
	build := func(backend http.Handler) *TraefikOidc {
		mw := &TraefikOidc{
			next:                         backend,
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}),
			sessionManager:               sm,
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			issuerURL:                    "https://provider.example.com",
		}
		close(mw.initComplete)
		return mw
	}

	// wsRequest returns GET /ws with the given Connection header. An empty
	// upgrade argument leaves the Upgrade header unset.
	wsRequest := func(connection, upgrade string) *http.Request {
		r := httptest.NewRequest("GET", "/ws", nil)
		r.Header.Set("Connection", connection)
		if upgrade != "" {
			r.Header.Set("Upgrade", upgrade)
		}
		return r
	}

	t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
		reached := false
		mw := build(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			reached = true
		}))

		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, wsRequest("Upgrade", "websocket"))

		if rec.Code != http.StatusUnauthorized {
			t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rec.Code)
		}
		if reached {
			t.Error("backend handler must NOT be called for unauthenticated WS bypass")
		}
	})

	t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
		var (
			reached bool
			user    string
		)
		mw := build(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
			reached = true
			user = r.Header.Get("X-Forwarded-User")
		}))
		// Use a mixed-case, multi-token Connection header to exercise header parsing.
		req := wsRequest("keep-alive, Upgrade", "WebSocket")

		sess, err := sm.GetSession(req)
		if err != nil {
			t.Fatalf("failed to create test session: %v", err)
		}
		sess.SetUserIdentifier("ws-user@example.com")
		if err = sess.SetAuthenticated(true); err != nil {
			t.Fatalf("failed to mark session authenticated: %v", err)
		}
		cookieSink := httptest.NewRecorder()
		if err = sess.Save(req, cookieSink); err != nil {
			t.Fatalf("failed to save session: %v", err)
		}
		for _, c := range cookieSink.Result().Cookies() {
			req.AddCookie(c)
		}

		mw.ServeHTTP(httptest.NewRecorder(), req)

		if !reached {
			t.Fatal("expected authenticated WS handshake to be forwarded to backend")
		}
		if user != "ws-user@example.com" {
			t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", user)
		}
	})

	t.Run("plain_http_does_not_bypass", func(t *testing.T) {
		// A request without Upgrade headers must not take the WS bypass
		// branch. If it did, that branch could short-circuit normal
		// authentication.
		mw := build(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			t.Fatal("backend must not be called for unauthenticated plain HTTP")
		}))

		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, wsRequest("keep-alive", ""))

		if rec.Code == http.StatusOK {
			t.Error("expected redirect or 401 for plain HTTP without auth, got 200")
		}
	})
}
// TestServeHTTP_InitializationTimeout tests how ServeHTTP gates requests on
// the initComplete signal.
func TestServeHTTP_InitializationTimeout(t *testing.T) {
	// ServeHTTP's own upper bound on the initialization wait does not appear
	// to be configurable from a test, so this subtest does not wait it out.
	// Instead it pins the observable contract: while initComplete is open the
	// request is held (ServeHTTP does not return and the backend is not
	// reached). Once initComplete is closed, the request proceeds and the
	// ServeHTTP goroutine is joined, so nothing leaks past the test.
	t.Run("request is held until initialization completes", func(t *testing.T) {
		const (
			pendingWindow = 100 * time.Millisecond
			joinTimeout   = 5 * time.Second
		)
		nextCalled := false
		oidc := &TraefikOidc{
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}), // closed manually below
			sessionManager:               createTestSessionManager(t),
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			// issuerURL is set so that, once released, the request takes the
			// normal unauthenticated path, not the failed-initialization path.
			issuerURL:     "https://provider.example.com",
			redirURLPath:  "/callback",
			logoutURLPath: "/logout",
			next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				w.WriteHeader(http.StatusOK)
			}),
		}
		req := httptest.NewRequest("GET", "/protected", nil)
		rw := httptest.NewRecorder()

		// finished is closed when ServeHTTP returns. Receiving from it also
		// orders the goroutine's writes to rw and nextCalled before the reads
		// below, which keeps the test race-free.
		finished := make(chan struct{})
		go func() {
			defer close(finished)
			oidc.ServeHTTP(rw, req)
		}()

		select {
		case <-finished:
			t.Fatal("ServeHTTP returned before initialization completed; expected it to wait on initComplete")
		case <-time.After(pendingWindow):
			// Still waiting, as expected.
		}

		close(oidc.initComplete)
		select {
		case <-finished:
		case <-time.After(joinTimeout):
			t.Fatal("ServeHTTP did not return after initialization completed")
		}

		if nextCalled {
			t.Error("unauthenticated request must not reach the backend")
		}
		if rw.Code == http.StatusServiceUnavailable {
			t.Errorf("expected request to proceed past initialization, got %d", rw.Code)
		}
	})
	t.Run("successful initialization", func(t *testing.T) {
		oidc := &TraefikOidc{
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}),
			sessionManager:               createTestSessionManager(t),
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			issuerURL:                    "https://provider.example.com",
			redirURLPath:                 "/callback",
			logoutURLPath:                "/logout",
			next:                         http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
		}
		// Close init channel to signal completion
		close(oidc.initComplete)
		req := httptest.NewRequest("GET", "/protected", nil)
		rw := httptest.NewRecorder()
		oidc.ServeHTTP(rw, req)
		// Should not return an initialization error
		if rw.Code == http.StatusServiceUnavailable {
			t.Error("expected successful request after initialization")
		}
	})
}
// TestServeHTTP_CallbackAndLogout verifies that requests to redirURLPath and
// logoutURLPath are routed to the callback and logout handlers respectively.
func TestServeHTTP_CallbackAndLogout(t *testing.T) {
	t.Run("callback path triggers callback handler", func(t *testing.T) {
		oidc := &TraefikOidc{
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}),
			sessionManager:               createTestSessionManager(t),
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			issuerURL:                    "https://provider.example.com",
			redirURLPath:                 "/callback",
			logoutURLPath:                "/logout",
			tokenURL:                     "https://provider.example.com/token",
			clientID:                     "test-client",
			audience:                     "test-client",
			clientSecret:                 "test-secret",
			tokenHTTPClient:              http.DefaultClient,
		}
		close(oidc.initComplete)
		req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
		rw := httptest.NewRecorder()
		// This will trigger handleCallback
		oidc.ServeHTTP(rw, req)
		// NewRecorder starts with Code 200 and an empty body, so "nothing was
		// written" means the recorder is still in that initial state. A
		// check for rw.Code == 0 can never fire. A response here is expected
		// even if it is an error caused by the invalid code/state.
		if rw.Code == http.StatusOK && rw.Body.Len() == 0 {
			t.Error("expected response from callback handler")
		}
	})
	t.Run("logout path triggers logout handler", func(t *testing.T) {
		oidc := &TraefikOidc{
			logger:                       NewLogger("debug"),
			initComplete:                 make(chan struct{}),
			sessionManager:               createTestSessionManager(t),
			firstRequestStarted:          1,
			metadataRefreshStartedAtomic: 1,
			issuerURL:                    "https://provider.example.com",
			redirURLPath:                 "/callback",
			logoutURLPath:                "/logout",
			endSessionURL:                "https://provider.example.com/logout",
			postLogoutRedirectURI:        "https://example.com",
		}
		close(oidc.initComplete)
		req := httptest.NewRequest("GET", "/logout", nil)
		rw := httptest.NewRecorder()
		// This will trigger handleLogout
		oidc.ServeHTTP(rw, req)
		// Check that we got a redirect response
		if rw.Code != http.StatusFound && rw.Code != http.StatusSeeOther {
			t.Errorf("expected redirect response, got %d", rw.Code)
		}
	})
}
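// TestProcessAuthorizedRequest (commented out below) is disabled because
// processAuthorizedRequest takes a concrete *SessionData, which needs more
// setup than the MockSessionData fixtures provide. Re-enabling it also
// requires importing "strings" and a template package. processAuthorizedRequest
// is exercised directly by TestMinimalHeaders and TestStripAuthCookies, and
// indirectly through the ServeHTTP tests.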
/*
func TestProcessAuthorizedRequest(t *testing.T) {
tests := []struct {
name string
setupSession func() *MockSessionData
setupOidc func() *TraefikOidc
expectedHeaders map[string]string
expectNextCalled bool
expectReauth bool
expectedStatus int
}{
{
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
}
return session
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-Forwarded-User": "user@example.com",
"X-Auth-Request-User": "user@example.com",
"X-Auth-Request-Token": "test-id-token",
},
expectNextCalled: true,
expectReauth: false,
},
{
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
authURL: "https://provider.example.com/auth",
clientID: "test-client",
audience: "test-client",
redirURLPath: "/callback",
}
},
expectNextCalled: false,
expectReauth: true,
},
{
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"users": {},
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"groups": []interface{}{"users", "developers"},
"roles": []interface{}{"viewer"},
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-User-Groups": "users,developers",
"X-User-Roles": "viewer",
},
expectNextCalled: true,
},
{
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
logoutURLPath: "/logout",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"groups": []interface{}{"users"},
"roles": []interface{}{"viewer"},
}, nil
},
}
},
expectNextCalled: false,
expectedStatus: http.StatusForbidden,
},
{
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
}
},
setupOidc: func() *TraefikOidc {
tmpl, _ := template.New("test").Parse("{{.Claims.email}}")
return &TraefikOidc{
logger: NewLogger("debug"),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
headerTemplates: map[string]*template.Template{
"X-Custom-Email": tmpl,
},
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
},
expectedHeaders: map[string]string{
"X-Custom-Email": "user@example.com",
},
expectNextCalled: true,
},
{
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
},
setupOidc: func() *TraefikOidc {
return &TraefikOidc{
logger: NewLogger("debug"),
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{}, nil
},
}
},
expectNextCalled: false, // OPTIONS returns immediately
expectedStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
oidc := tt.setupOidc()
req := httptest.NewRequest("GET", "/protected", nil)
if strings.Contains(tt.name, "OPTIONS") {
req = httptest.NewRequest("OPTIONS", "/protected", nil)
req.Header.Set("Origin", "https://example.com")
}
rw := httptest.NewRecorder()
nextCalled := false
if oidc.next == nil {
oidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
} else {
originalNext := oidc.next
oidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
originalNext.ServeHTTP(w, r)
})
}
// Call the function - we need to use the concrete SessionData type
// For testing, we'll create a minimal SessionData that implements what we need
concreteSession := &SessionData{
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
concreteSession.SetAuthenticated(session.authenticated)
if session.isDirty {
concreteSession.MarkDirty()
}
oidc.processAuthorizedRequest(rw, req, concreteSession, "https://example.com/callback")
// Verify expectations
if tt.expectNextCalled && !nextCalled {
t.Error("expected next handler to be called")
}
if !tt.expectNextCalled && nextCalled {
t.Error("expected next handler NOT to be called")
}
// Check headers
for header, expectedValue := range tt.expectedHeaders {
if got := req.Header.Get(header); got != expectedValue {
t.Errorf("expected header %s = %q, got %q", header, expectedValue, got)
}
}
// Check status code if specified
if tt.expectedStatus > 0 && rw.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, rw.Code)
}
// Check security headers are set
securityHeaders := []string{
"X-Frame-Options",
"X-Content-Type-Options",
"X-XSS-Protection",
"Referrer-Policy",
}
for _, header := range securityHeaders {
if rw.Header().Get(header) == "" {
t.Errorf("expected security header %s to be set", header)
}
}
})
}
}
*/
// MockSessionData is an in-memory test double that mirrors the getter/setter
// surface of SessionData. SessionData is a concrete type, not an interface, so
// this mock cannot be passed where a *SessionData is required. It is used by
// the commented-out TestProcessAuthorizedRequest. Every setter stores its
// argument in the matching field, Save and Clear are no-ops that always return
// nil, and nothing here is safe for concurrent use.
type MockSessionData struct {
	userIdentifier string
	idToken        string
	accessToken    string
	refreshToken   string
	csrf           string
	nonce          string
	codeVerifier   string
	redirectCount  int
	authenticated  bool
	isDirty        bool
}

func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string        { return m.idToken }
func (m *MockSessionData) GetAccessToken() string    { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string   { return m.refreshToken }

func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string)                 { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string)             { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string)            { m.refreshToken = token }

func (m *MockSessionData) SetAuthenticated(auth bool) { m.authenticated = auth }
func (m *MockSessionData) IsAuthenticated() bool      { return m.authenticated }
func (m *MockSessionData) IsDirty() bool              { return m.isDirty }
func (m *MockSessionData) MarkDirty()                 { m.isDirty = true }

// ResetRedirectCount zeroes the redirect counter.
func (m *MockSessionData) ResetRedirectCount() { m.redirectCount = 0 }

// IncrementRedirectCount bumps the redirect counter and returns the new value.
func (m *MockSessionData) IncrementRedirectCount() int { m.redirectCount++; return m.redirectCount }

func (m *MockSessionData) GetCSRF() string                   { return m.csrf }
func (m *MockSessionData) SetCSRF(csrf string)               { m.csrf = csrf }
func (m *MockSessionData) GetNonce() string                  { return m.nonce }
func (m *MockSessionData) SetNonce(nonce string)             { m.nonce = nonce }
func (m *MockSessionData) GetCodeVerifier() string           { return m.codeVerifier }
func (m *MockSessionData) SetCodeVerifier(verifier string)   { m.codeVerifier = verifier }

// Save is a no-op for the mock and always succeeds.
func (m *MockSessionData) Save(r *http.Request, w http.ResponseWriter) error { return nil }

// Clear is a no-op for the mock and always succeeds.
func (m *MockSessionData) Clear(r *http.Request, w http.ResponseWriter) error { return nil }
// createTestSessionManager builds a SessionManager for tests, using a fixed
// test encryption key and a debug-level logger. If construction fails, it
// fails the calling test immediately. t.Helper() makes that failure point at
// the caller's line rather than this helper's.
func createTestSessionManager(t *testing.T) *SessionManager {
	t.Helper()
	sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	return sm
}
// TestMinimalHeaders tests the minimalHeaders configuration option.
// This addresses GitHub issue #64 - Request Header Fields Too Large.
// X-Forwarded-User must always be forwarded. The X-Auth-Request-* headers are
// forwarded only when minimalHeaders is false.
func TestMinimalHeaders(t *testing.T) {
	tests := []struct {
		name                      string
		minimalHeaders            bool
		expectForwardedUser       bool
		expectAuthRequestUser     bool
		expectAuthRequestRedirect bool
	}{
		{
			name:                      "minimalHeaders=false (default) forwards all headers",
			minimalHeaders:            false,
			expectForwardedUser:       true,
			expectAuthRequestUser:     true,
			expectAuthRequestRedirect: true,
		},
		{
			name:                      "minimalHeaders=true only forwards X-Forwarded-User",
			minimalHeaders:            true,
			expectForwardedUser:       true,
			expectAuthRequestUser:     false,
			expectAuthRequestRedirect: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Record whether the backend ran and which headers it received.
			var (
				nextCalled      bool
				capturedHeaders http.Header
			)
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				capturedHeaders = r.Header.Clone()
				w.WriteHeader(http.StatusOK)
			})
			sessionManager := createTestSessionManager(t)
			oidc := &TraefikOidc{
				next:                         next,
				logger:                       NewLogger("debug"),
				initComplete:                 make(chan struct{}),
				sessionManager:               sessionManager,
				firstRequestStarted:          1,
				metadataRefreshStartedAtomic: 1,
				issuerURL:                    "https://provider.example.com",
				minimalHeaders:               tt.minimalHeaders,
				extractClaimsFunc: func(token string) (map[string]interface{}, error) {
					return map[string]interface{}{
						"email": "user@example.com",
					}, nil
				},
			}
			close(oidc.initComplete)
			// Create request and get session properly through session manager
			req := httptest.NewRequest("GET", "/protected", nil)
			rw := httptest.NewRecorder()
			session, err := sessionManager.GetSession(req)
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}
			// Set up session data
			session.SetUserIdentifier("user@example.com")
			if err := session.SetAuthenticated(true); err != nil {
				t.Fatalf("failed to mark session authenticated: %v", err)
			}
			// Call processAuthorizedRequest directly
			oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
			// If the backend was never reached, capturedHeaders is nil and the
			// "NOT set" checks below would pass vacuously, so check first.
			if !nextCalled {
				t.Fatalf("expected request to be forwarded to backend, got status %d", rw.Code)
			}
			// Verify X-Forwarded-User is always set
			if tt.expectForwardedUser {
				if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
					t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
				}
			}
			// Verify X-Auth-Request-User
			hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
			if tt.expectAuthRequestUser && !hasAuthRequestUser {
				t.Error("expected X-Auth-Request-User to be set")
			}
			if !tt.expectAuthRequestUser && hasAuthRequestUser {
				t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
			}
			// Verify X-Auth-Request-Redirect
			hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
			if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
				t.Error("expected X-Auth-Request-Redirect to be set")
			}
			if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
				t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
			}
			// Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty.
			// Token storage has validation that may reject test tokens, so the flag
			// logic is verified through the other headers. The important behavior is that when
			// minimalHeaders=true, the token header would NOT be set even if a token existed.
		})
	}
}
// TestMinimalHeaders_TokenHeaderNotSet verifies that with minimalHeaders
// enabled, processAuthorizedRequest forwards X-Forwarded-User but none of the
// X-Auth-Request-* headers (Token, User, Redirect).
//
// The session is not seeded with an ID token, because token storage
// validation may reject synthetic tokens. The X-Auth-Request-Token check
// therefore only confirms that the header is absent in this configuration.
// It does not, by itself, prove that a real stored token would be suppressed.
func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
	var (
		nextCalled      bool
		capturedHeaders http.Header
	)
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		nextCalled = true
		capturedHeaders = r.Header.Clone()
		w.WriteHeader(http.StatusOK)
	})
	sessionManager := createTestSessionManager(t)
	oidc := &TraefikOidc{
		next:                         next,
		logger:                       NewLogger("debug"),
		initComplete:                 make(chan struct{}),
		sessionManager:               sessionManager,
		firstRequestStarted:          1,
		metadataRefreshStartedAtomic: 1,
		issuerURL:                    "https://provider.example.com",
		minimalHeaders:               true, // Enable minimal headers
		extractClaimsFunc: func(token string) (map[string]interface{}, error) {
			return map[string]interface{}{
				"email": "user@example.com",
			}, nil
		},
	}
	close(oidc.initComplete)
	req := httptest.NewRequest("GET", "/protected", nil)
	rw := httptest.NewRecorder()
	session, err := sessionManager.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	session.SetUserIdentifier("user@example.com")
	if err := session.SetAuthenticated(true); err != nil {
		t.Fatalf("failed to mark session authenticated: %v", err)
	}
	oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
	// The "NOT set" checks below would pass vacuously on a nil header map if
	// the backend was never reached.
	if !nextCalled {
		t.Fatalf("expected request to be forwarded to backend, got status %d", rw.Code)
	}
	// Verify X-Forwarded-User is set (always should be)
	if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
		t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
	}
	// X-Auth-Request-Token should NOT be set with minimalHeaders=true
	if capturedHeaders.Get("X-Auth-Request-Token") != "" {
		t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true")
	}
	// X-Auth-Request-User should also NOT be set with minimalHeaders=true
	if capturedHeaders.Get("X-Auth-Request-User") != "" {
		t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true")
	}
	// X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true
	if capturedHeaders.Get("X-Auth-Request-Redirect") != "" {
		t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
	}
}
// TestStripAuthCookies tests the stripAuthCookies configuration option.
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
// When enabled, cookies whose names start with the session manager's cookie
// prefix are removed before forwarding. Application cookies always pass through.
func TestStripAuthCookies(t *testing.T) {
	tests := []struct {
		name              string
		stripAuthCookies  bool
		expectOIDCCookies bool
		expectAppCookies  bool
	}{
		{
			name:              "stripAuthCookies=false (default) forwards all cookies",
			stripAuthCookies:  false,
			expectOIDCCookies: true,
			expectAppCookies:  true,
		},
		{
			name:              "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
			stripAuthCookies:  true,
			expectOIDCCookies: false,
			expectAppCookies:  true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var (
				nextCalled      bool
				capturedCookies []*http.Cookie
			)
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextCalled = true
				capturedCookies = r.Cookies()
				w.WriteHeader(http.StatusOK)
			})
			sessionManager := createTestSessionManager(t)
			cookiePrefix := sessionManager.GetCookiePrefix()
			oidc := &TraefikOidc{
				next:                         next,
				logger:                       NewLogger("debug"),
				initComplete:                 make(chan struct{}),
				sessionManager:               sessionManager,
				firstRequestStarted:          1,
				metadataRefreshStartedAtomic: 1,
				issuerURL:                    "https://provider.example.com",
				stripAuthCookies:             tt.stripAuthCookies,
				extractClaimsFunc: func(token string) (map[string]interface{}, error) {
					return map[string]interface{}{
						"email": "user@example.com",
					}, nil
				},
			}
			close(oidc.initComplete)
			req := httptest.NewRequest("GET", "/protected", nil)
			rw := httptest.NewRecorder()
			// Get a valid session first (before adding fake cookies)
			session, err := sessionManager.GetSession(req)
			if err != nil {
				t.Fatalf("Failed to get session: %v", err)
			}
			session.SetUserIdentifier("user@example.com")
			if err := session.SetAuthenticated(true); err != nil {
				t.Fatalf("failed to mark session authenticated: %v", err)
			}
			// Now add OIDC session cookies (simulating what the browser would send)
			req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
			req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
			req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
			req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
			req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
			// Add non-OIDC application cookies (these must always pass through)
			req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
			req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
			oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
			// If the backend never ran, capturedCookies is nil and the
			// "stripped" expectation would hold vacuously.
			if !nextCalled {
				t.Fatalf("expected request to be forwarded to backend, got status %d", rw.Code)
			}
			// Check for OIDC cookies in captured cookies
			hasOIDCCookie := false
			hasAppSession := false
			hasTheme := false
			for _, c := range capturedCookies {
				// Manual prefix test (equivalent to strings.HasPrefix).
				if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
					hasOIDCCookie = true
				}
				if c.Name == "my_app_session" {
					hasAppSession = true
				}
				if c.Name == "theme" {
					hasTheme = true
				}
			}
			if tt.expectOIDCCookies && !hasOIDCCookie {
				t.Error("expected OIDC cookies to be forwarded to backend")
			}
			if !tt.expectOIDCCookies && hasOIDCCookie {
				t.Error("expected OIDC cookies to be stripped before forwarding to backend")
			}
			if tt.expectAppCookies && !hasAppSession {
				t.Error("expected my_app_session cookie to be forwarded to backend")
			}
			if tt.expectAppCookies && !hasTheme {
				t.Error("expected theme cookie to be forwarded to backend")
			}
		})
	}
}
// TestStripAuthCookies_NoCookies verifies that enabling stripAuthCookies is
// harmless when the incoming request carries no cookies at all: the request
// must still be forwarded to the next handler, and that handler must observe
// zero cookies.
func TestStripAuthCookies_NoCookies(t *testing.T) {
	var capturedCookies []*http.Cookie
	// nextCalled guards against a vacuous pass: capturedCookies is also empty
	// when the next handler is never invoked at all.
	nextCalled := false
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		nextCalled = true
		capturedCookies = r.Cookies()
		w.WriteHeader(http.StatusOK)
	})
	sessionManager := createTestSessionManager(t)
	oidc := &TraefikOidc{
		next:           next,
		logger:         NewLogger("debug"),
		initComplete:   make(chan struct{}),
		sessionManager: sessionManager,
		// Presumably marks the one-shot ServeHTTP bootstrap flags as already
		// taken so no bootstrap work is triggered — confirm against ServeHTTP.
		firstRequestStarted:          1,
		metadataRefreshStartedAtomic: 1,
		issuerURL:                    "https://provider.example.com",
		stripAuthCookies:             true,
		extractClaimsFunc: func(token string) (map[string]interface{}, error) {
			return map[string]interface{}{"email": "user@example.com"}, nil
		},
	}
	close(oidc.initComplete)
	req := httptest.NewRequest("GET", "/protected", nil)
	rw := httptest.NewRecorder()
	session, err := sessionManager.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	session.SetUserIdentifier("user@example.com")
	session.SetAuthenticated(true)
	oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
	if !nextCalled {
		t.Fatal("expected request to be forwarded to the next handler")
	}
	if len(capturedCookies) != 0 {
		t.Errorf("expected no cookies, got %d", len(capturedCookies))
	}
}
// TestStripAuthCookies_OnlyOIDCCookies verifies that when every cookie on the
// request carries the session manager's cookie prefix, stripping removes all
// of them: the next handler must still be reached, must see no parsed cookies,
// and must see an empty Cookie header.
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
	var capturedCookieHeader string
	var capturedCookies []*http.Cookie
	// nextCalled guards against a vacuous pass: both the cookie slice and the
	// header are also empty when the next handler is never invoked.
	nextCalled := false
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		nextCalled = true
		capturedCookieHeader = r.Header.Get("Cookie")
		capturedCookies = r.Cookies()
		w.WriteHeader(http.StatusOK)
	})
	sessionManager := createTestSessionManager(t)
	cookiePrefix := sessionManager.GetCookiePrefix()
	oidc := &TraefikOidc{
		next:           next,
		logger:         NewLogger("debug"),
		initComplete:   make(chan struct{}),
		sessionManager: sessionManager,
		// Presumably marks the one-shot ServeHTTP bootstrap flags as already
		// taken so no bootstrap work is triggered — confirm against ServeHTTP.
		firstRequestStarted:          1,
		metadataRefreshStartedAtomic: 1,
		issuerURL:                    "https://provider.example.com",
		stripAuthCookies:             true,
		extractClaimsFunc: func(token string) (map[string]interface{}, error) {
			return map[string]interface{}{"email": "user@example.com"}, nil
		},
	}
	close(oidc.initComplete)
	req := httptest.NewRequest("GET", "/protected", nil)
	rw := httptest.NewRecorder()
	session, err := sessionManager.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	session.SetUserIdentifier("user@example.com")
	session.SetAuthenticated(true)
	// Add only OIDC cookies (added after GetSession so they do not influence
	// session loading).
	req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
	req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
	req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
	oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
	if !nextCalled {
		t.Fatal("expected request to be forwarded to the next handler")
	}
	if len(capturedCookies) != 0 {
		t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
	}
	if capturedCookieHeader != "" {
		t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
	}
}
// TestStripAuthCookies_OnlyAppCookies checks that cookie stripping leaves
// application cookies alone: a request carrying only cookies without the OIDC
// prefix must reach the next handler with all three of them intact.
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
	var forwardedCookies []*http.Cookie
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		forwardedCookies = r.Cookies()
		w.WriteHeader(http.StatusOK)
	})
	manager := createTestSessionManager(t)
	oidc := &TraefikOidc{
		next:                         next,
		logger:                       NewLogger("debug"),
		initComplete:                 make(chan struct{}),
		sessionManager:               manager,
		firstRequestStarted:          1,
		metadataRefreshStartedAtomic: 1,
		issuerURL:                    "https://provider.example.com",
		stripAuthCookies:             true,
		extractClaimsFunc: func(token string) (map[string]interface{}, error) {
			return map[string]interface{}{"email": "user@example.com"}, nil
		},
	}
	close(oidc.initComplete)

	req := httptest.NewRequest("GET", "/protected", nil)
	recorder := httptest.NewRecorder()
	session, err := manager.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	session.SetUserIdentifier("user@example.com")
	session.SetAuthenticated(true)

	// Only non-OIDC cookies are attached; none of them should be touched.
	for _, c := range []*http.Cookie{
		{Name: "my_app_session", Value: "abc123"},
		{Name: "lang", Value: "en"},
		{Name: "theme", Value: "dark"},
	} {
		req.AddCookie(c)
	}
	oidc.processAuthorizedRequest(recorder, req, session, "https://example.com/callback")

	if got := len(forwardedCookies); got != 3 {
		t.Errorf("expected 3 cookies, got %d", got)
	}
	seen := make(map[string]struct{}, len(forwardedCookies))
	for _, c := range forwardedCookies {
		seen[c.Name] = struct{}{}
	}
	for _, name := range []string{"my_app_session", "lang", "theme"} {
		if _, ok := seen[name]; !ok {
			t.Errorf("expected cookie %q to be forwarded", name)
		}
	}
}
// TestStripAuthCookies_CustomPrefix checks that stripping keys off the
// session manager's configured cookie prefix. Cookies carrying the custom
// "myapp_oidc_" prefix must be removed, while a cookie using a different
// (default-looking) prefix and a plain application cookie must pass through.
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
	var forwardedCookies []*http.Cookie
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		forwardedCookies = r.Cookies()
		w.WriteHeader(http.StatusOK)
	})

	// Build a session manager with a custom cookie prefix.
	customManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
	if err != nil {
		t.Fatalf("Failed to create session manager: %v", err)
	}
	prefix := customManager.GetCookiePrefix()

	oidc := &TraefikOidc{
		next:                         next,
		logger:                       NewLogger("debug"),
		initComplete:                 make(chan struct{}),
		sessionManager:               customManager,
		firstRequestStarted:          1,
		metadataRefreshStartedAtomic: 1,
		issuerURL:                    "https://provider.example.com",
		stripAuthCookies:             true,
		extractClaimsFunc: func(token string) (map[string]interface{}, error) {
			return map[string]interface{}{"email": "user@example.com"}, nil
		},
	}
	close(oidc.initComplete)

	req := httptest.NewRequest("GET", "/protected", nil)
	recorder := httptest.NewRecorder()
	session, err := customManager.GetSession(req)
	if err != nil {
		t.Fatalf("Failed to get session: %v", err)
	}
	session.SetUserIdentifier("user@example.com")
	session.SetAuthenticated(true)

	stripped := []string{prefix + "m", prefix + "s_0"}
	req.AddCookie(&http.Cookie{Name: stripped[0], Value: "session-data"})
	req.AddCookie(&http.Cookie{Name: stripped[1], Value: "chunk0"})
	// Different prefix than the configured one, so it must survive.
	req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
	// Ordinary application cookie, must survive.
	req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})

	oidc.processAuthorizedRequest(recorder, req, session, "https://example.com/callback")

	present := make(map[string]struct{}, len(forwardedCookies))
	for _, c := range forwardedCookies {
		present[c.Name] = struct{}{}
	}
	has := func(name string) bool {
		_, ok := present[name]
		return ok
	}

	for _, name := range stripped {
		if has(name) {
			t.Errorf("expected cookie %q to be stripped", name)
		}
	}
	if !has("_oidc_raczylo_m") {
		t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
	}
	if !has("my_app") {
		t.Error("expected my_app cookie to pass through")
	}
}