diff --git a/bearer_auth_test.go b/bearer_auth_test.go index 94fcf60..118ef5f 100644 --- a/bearer_auth_test.go +++ b/bearer_auth_test.go @@ -71,8 +71,8 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc { logger: NewLogger("error"), initComplete: make(chan struct{}), sessionManager: sm, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://issuer.example.com", audience: "https://api.example.com", clientID: "https://api.example.com", diff --git a/issue67_regression_test.go b/issue67_regression_test.go index 9c4a3ca..8d2b2a5 100644 --- a/issue67_regression_test.go +++ b/issue67_regression_test.go @@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) { // Test 3: Rate limiting t.Run("RateLimiting", func(t *testing.T) { - // Reset circuit breaker to closed state for this test - coordinator.circuitBreaker.mutex.Lock() + // Reset circuit breaker to closed state for this test. All fields are + // atomic so we don't need any mutex. 
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0) - coordinator.circuitBreaker.mutex.Unlock() // Temporarily increase circuit breaker threshold to not interfere oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures @@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) { time.Sleep(config.CleanupInterval * 3) // Old sessions should be cleaned up - coordinator.attemptsMutex.RLock() - count := len(coordinator.sessionRefreshAttempts) - coordinator.attemptsMutex.RUnlock() + count := 0 + coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool { + count++ + return true + }) // Should have fewer sessions after cleanup if count > 10 { diff --git a/logout_test.go b/logout_test.go index 54597d6..4a42d14 100644 --- a/logout_test.go +++ b/logout_test.go @@ -415,8 +415,8 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) { clientID: "test-client", issuerURL: "https://provider.example.com", initComplete: make(chan struct{}), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, logoutURLPath: "/logout", } close(oidc.initComplete) @@ -457,8 +457,8 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) { clientID: "test-client", issuerURL: "https://provider.example.com", initComplete: make(chan struct{}), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, logoutURLPath: "/logout", } close(oidc.initComplete) diff --git a/main_initialization_test.go b/main_initialization_test.go index 45b9438..6906fa0 100644 --- a/main_initialization_test.go +++ b/main_initialization_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -484,9 +485,8 @@ func TestFirstRequestHandling(t *testing.T) { defer server.Close() oidc := &TraefikOidc{ - providerURL: server.URL, - 
firstRequestReceived: false, - firstRequestMutex: sync.Mutex{}, + providerURL: server.URL, + firstRequestStarted: 0, httpClient: &http.Client{ Timeout: 5 * time.Second, }, @@ -508,19 +508,13 @@ func TestFirstRequestHandling(t *testing.T) { }, } - // Simulate first request processing - oidc.firstRequestMutex.Lock() - if !oidc.firstRequestReceived { - oidc.firstRequestReceived = true - oidc.firstRequestMutex.Unlock() - + // Simulate first request processing — single-firing via CAS. + if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) { // This would normally be called asynchronously go func() { oidc.initializeMetadata(server.URL) // initComplete is closed internally by initializeMetadata }() - } else { - oidc.firstRequestMutex.Unlock() } // Wait for initialization @@ -556,9 +550,8 @@ func TestFirstRequestHandling(t *testing.T) { defer server.Close() oidc := &TraefikOidc{ - providerURL: server.URL, - firstRequestReceived: false, - firstRequestMutex: sync.Mutex{}, + providerURL: server.URL, + firstRequestStarted: 0, httpClient: &http.Client{ Timeout: 5 * time.Second, }, @@ -580,31 +573,22 @@ func TestFirstRequestHandling(t *testing.T) { }, } - // Simulate multiple concurrent "first" requests + // Simulate multiple concurrent "first" requests — only one CAS winner + // fires the bootstrap path. 
const numRequests = 10 var wg sync.WaitGroup wg.Add(numRequests) - initStarted := 0 - var initMu sync.Mutex + var initStarted int32 for i := 0; i < numRequests; i++ { go func() { defer wg.Done() - oidc.firstRequestMutex.Lock() - if !oidc.firstRequestReceived { - oidc.firstRequestReceived = true - oidc.firstRequestMutex.Unlock() - - initMu.Lock() - initStarted++ - initMu.Unlock() - + if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) { + atomic.AddInt32(&initStarted, 1) // Only one should actually start initialization oidc.initializeMetadata(server.URL) - } else { - oidc.firstRequestMutex.Unlock() } }() } @@ -612,8 +596,8 @@ func TestFirstRequestHandling(t *testing.T) { wg.Wait() // Verify only one initialization was started - if initStarted != 1 { - t.Errorf("expected exactly 1 initialization, got %d", initStarted) + if atomic.LoadInt32(&initStarted) != 1 { + t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted)) } // The metadata endpoint might be called once or not at all depending on timing diff --git a/main_servehttp_test.go b/main_servehttp_test.go index c2aab43..272973b 100644 --- a/main_servehttp_test.go +++ b/main_servehttp_test.go @@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", // Required for initialization check } close(oidc.initComplete) @@ -92,8 +92,8 @@ func TestServeHTTP_EventStream(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", } close(oidc.initComplete) @@ -175,8 +175,8 @@ func 
TestServeHTTP_WebSocketUpgrade(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", } close(oidc.initComplete) @@ -272,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), // Never close this to simulate timeout sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, } req := httptest.NewRequest("GET", "/protected", nil) @@ -307,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -337,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -367,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -740,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) 
{ logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", minimalHeaders: tt.minimalHeaders, extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -817,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", minimalHeaders: true, // Enable minimal headers extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -903,8 +903,8 @@ func TestStripAuthCookies(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", stripAuthCookies: tt.stripAuthCookies, extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -987,8 +987,8 @@ func TestStripAuthCookies_NoCookies(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", stripAuthCookies: true, extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -1034,8 +1034,8 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, 
issuerURL: "https://provider.example.com", stripAuthCookies: true, extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -1085,8 +1085,8 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", stripAuthCookies: true, extractClaimsFunc: func(token string) (map[string]interface{}, error) { @@ -1148,8 +1148,8 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sm, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", stripAuthCookies: true, extractClaimsFunc: func(token string) (map[string]interface{}, error) { diff --git a/main_test.go b/main_test.go index 34b8ea1..58e487c 100644 --- a/main_test.go +++ b/main_test.go @@ -16,6 +16,7 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "testing" "time" @@ -2685,10 +2686,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) { providerAvailable = true mu.Unlock() - // Reset the retry timer to allow immediate retry - m.metadataRetryMutex.Lock() - m.lastMetadataRetryTime = time.Time{} // Reset to zero time - m.metadataRetryMutex.Unlock() + // Reset the retry timer to allow immediate retry. The field is atomic + // now, so no lock is needed. 
+ atomic.StoreInt64(&m.lastMetadataRetryNano, 0) // Second request should trigger recovery attempt req2 := httptest.NewRequest("GET", "/protected", nil) diff --git a/middleware.go b/middleware.go index 9d6a5d7..6dc989e 100644 --- a/middleware.go +++ b/middleware.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync/atomic" "time" "github.com/lukaszraczylo/traefikoidc/internal/utils" @@ -145,19 +146,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if !strings.HasPrefix(req.URL.Path, "/health") { - t.firstRequestMutex.Lock() - if !t.firstRequestReceived { - t.firstRequestReceived = true + // Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock() + // fired on EVERY non-health request forever (even after the boolean + // flipped true), which under Yaegi added a per-request serialization + // point. CAS gives single-firing semantics with zero steady-state cost. + if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) { t.logger.Debug("Starting background tasks on first request") t.startTokenCleanup() - if !t.metadataRefreshStarted && t.providerURL != "" { - t.metadataRefreshStarted = true + if t.providerURL != "" && + atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) { // Metadata refresh is handled by singleton resource manager t.startMetadataRefresh(t.providerURL) } } - t.firstRequestMutex.Unlock() } // Evaluate auth-bypass once, before waiting for initialization. 
Excluded @@ -213,14 +215,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.metadataMu.RUnlock() if issuerURL == "" { - // Provider metadata initialization failed - try to recover - // Retry every 30 seconds to allow automatic recovery when provider comes back online - t.metadataRetryMutex.Lock() - shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second - if shouldRetry { - t.lastMetadataRetryTime = time.Now() - } - t.metadataRetryMutex.Unlock() + // Provider metadata initialization failed - try to recover. + // Retry every 30 seconds to allow automatic recovery. Lock-free + // throttle via CAS on lastMetadataRetryNano: one goroutine wins + // the window, others see shouldRetry=false. + nowNano := time.Now().UnixNano() + last := atomic.LoadInt64(&t.lastMetadataRetryNano) + shouldRetry := time.Duration(nowNano-last) >= 30*time.Second && + atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano) if shouldRetry && t.providerURL != "" { t.logger.Info("Attempting to recover OIDC provider metadata...") diff --git a/middleware_edge_cases_test.go b/middleware_edge_cases_test.go index 5aab76c..0bfccc2 100644 --- a/middleware_edge_cases_test.go +++ b/middleware_edge_cases_test.go @@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), // Never close to simulate waiting sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, } // Create request with canceled context @@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: 
"/callback", logoutURLPath: "/logout", @@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), // Never close to simulate provider unavailable sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, logoutURLPath: "/logout", postLogoutRedirectURI: "/", forceHTTPS: false, @@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: "/callback", logoutURLPath: "/logout", @@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) { logger: NewLogger("debug"), initComplete: make(chan struct{}), sessionManager: sessionManager, - firstRequestReceived: true, - metadataRefreshStarted: true, + firstRequestStarted: 1, + metadataRefreshStartedAtomic: 1, issuerURL: "https://provider.example.com", redirURLPath: 
"/callback", logoutURLPath: "/logout", diff --git a/refresh_coordinator.go b/refresh_coordinator.go index 587ca88..3edca74 100644 --- a/refresh_coordinator.go +++ b/refresh_coordinator.go @@ -21,16 +21,23 @@ type RefreshCoordinator struct { // refreshMutex.Lock() was held for tens of milliseconds per request due // to interpreter overhead on the work inside the critical section, // causing dozens of goroutines to stack up on it and pin one CPU core. - inFlightRefreshes sync.Map + inFlightRefreshes sync.Map + // sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker. + // sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/ + // recordRefreshSuccess/recordRefreshFailure are lock-free. Previously + // these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition + // adds 10-50ms of dispatch overhead, and they were called twice per leader + // request (once for recordRefreshAttempt, once for isInCooldown). That + // serializing pattern caused the v1.0.15 death spiral after v1.0.14 + // removed the refreshMutex (same architectural shape, different mutex). + sessionRefreshAttempts sync.Map cleanupTimers map[string]*time.Timer - sessionRefreshAttempts map[string]*refreshAttemptTracker circuitBreaker *RefreshCircuitBreaker metrics *RefreshMetrics logger *Logger stopChan chan struct{} config RefreshCoordinatorConfig wg sync.WaitGroup - attemptsMutex sync.RWMutex cleanupTimerMu sync.Mutex } @@ -89,14 +96,22 @@ type refreshResult struct { fromCache bool } -// refreshAttemptTracker tracks refresh attempts for a session +// refreshAttemptTracker tracks refresh attempts for a session. All fields are +// accessed via sync/atomic so isInCooldown/recordRefreshAttempt/Success/Failure +// can run without holding any per-coordinator lock. Times are UnixNano so they +// fit in an int64 and can be read with a single atomic.LoadInt64. +// +// cooldownEndNano == 0 means "not in cooldown". 
This sentinel replaces the +// inCooldown bool that the previous implementation kept under attemptsMutex — +// under Yaegi any per-request global mutex turns into a serializing bottleneck +// (the v1.0.14 refreshMutex -> sync.Map fix removed only one such bottleneck; +// attemptsMutex was the next one in the queue). type refreshAttemptTracker struct { - lastAttemptTime time.Time - windowStartTime time.Time - cooldownEndTime time.Time - attempts int32 - consecutiveFailures int32 - inCooldown bool + lastAttemptNano int64 // atomic, UnixNano of last attempt + windowStartNano int64 // atomic, UnixNano of attempt-window start + cooldownEndNano int64 // atomic, UnixNano; 0 = not in cooldown + attempts int32 // atomic + consecutiveFailures int32 // atomic } // RefreshMetrics tracks coordinator performance metrics @@ -111,14 +126,18 @@ type RefreshMetrics struct { currentInFlightRefreshes int32 } -// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations +// RefreshCircuitBreaker implements a circuit breaker specifically for refresh +// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/ +// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was +// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of +// interpreter dispatch per call, which compounded with attemptsMutex to keep +// the pod's single CPU core saturated. 
type RefreshCircuitBreaker struct { - lastFailureTime time.Time - lastSuccessTime time.Time + lastFailureNano int64 // atomic, UnixNano of most recent failure + lastSuccessNano int64 // atomic, UnixNano of most recent success config RefreshCircuitBreakerConfig - mutex sync.RWMutex - state int32 - failures int32 + state int32 // atomic: 0=closed, 1=open, 2=half-open + failures int32 // atomic } // RefreshCircuitBreakerConfig configures the refresh circuit breaker @@ -135,13 +154,13 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref } rc := &RefreshCoordinator{ - // inFlightRefreshes is a sync.Map; zero value is ready to use. - sessionRefreshAttempts: make(map[string]*refreshAttemptTracker), - config: config, - metrics: &RefreshMetrics{}, - logger: logger, - stopChan: make(chan struct{}), - cleanupTimers: make(map[string]*time.Timer), + // inFlightRefreshes and sessionRefreshAttempts are both sync.Map; + // their zero values are ready to use. + config: config, + metrics: &RefreshMetrics{}, + logger: logger, + stopChan: make(chan struct{}), + cleanupTimers: make(map[string]*time.Timer), circuitBreaker: &RefreshCircuitBreaker{ config: RefreshCircuitBreakerConfig{ MaxFailures: 3, @@ -415,86 +434,99 @@ func (rc *RefreshCoordinator) performCleanup(tokenHash string) { } } -// isInCooldown checks if a session is in cooldown after recording an attempt -func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool { - rc.attemptsMutex.Lock() - defer rc.attemptsMutex.Unlock() +// trackerFromMapValue centralizes the type assertion so the lint-mandated +// two-value form lives in one place; the stored type is always +// *refreshAttemptTracker by construction. +// +// getOrCreateTracker (below) fetches the tracker for sessionID or creates a +// fresh one via sync.Map.LoadOrStore, with no coordinator-level mutex; at +// most one tracker per sessionID survives concurrent first-touch races.
+func trackerFromMapValue(v interface{}) *refreshAttemptTracker { + t, _ := v.(*refreshAttemptTracker) + return t +} - tracker, exists := rc.sessionRefreshAttempts[sessionID] - if !exists { +func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker { + if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok { + return trackerFromMapValue(v) + } + fresh := &refreshAttemptTracker{ + windowStartNano: time.Now().UnixNano(), + } + actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh) + return trackerFromMapValue(actual) +} + +// isInCooldown checks if a session is in cooldown. Lock-free read with a +// best-effort cooldown-reset CAS on the cooldownEndNano sentinel. If the +// reset races with another goroutine we accept the loser's view (the winner's +// reset still happens). The attempt-window expiry and limit-exceeded paths +// are write-mostly but use atomic.StoreInt64/AddInt32 — never a held lock. +func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool { + v, ok := rc.sessionRefreshAttempts.Load(sessionID) + if !ok { return false // No tracker means first attempt, not in cooldown } - + tracker := trackerFromMapValue(v) now := time.Now() + nowNano := now.UnixNano() - // Check if already in cooldown - if tracker.inCooldown { - if now.After(tracker.cooldownEndTime) { - // Cooldown expired, reset tracker - tracker.inCooldown = false - tracker.attempts = 1 // Already recorded one attempt - tracker.consecutiveFailures = 0 - tracker.windowStartTime = now - return false + // Already in cooldown? + if cooldownEnd := atomic.LoadInt64(&tracker.cooldownEndNano); cooldownEnd != 0 { + if nowNano <= cooldownEnd { + return true // still in cooldown + } + // Cooldown expired. Best-effort reset (a concurrent caller may also + // reset; the result is equivalent — fresh window + one recorded + // attempt — so the CAS race is benign). 
+ if atomic.CompareAndSwapInt64(&tracker.cooldownEndNano, cooldownEnd, 0) { + atomic.StoreInt32(&tracker.attempts, 1) + atomic.StoreInt32(&tracker.consecutiveFailures, 0) + atomic.StoreInt64(&tracker.windowStartNano, nowNano) } - return true // Still in cooldown - } - - // Check if window expired - if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow { - // Reset window - tracker.attempts = 1 // Already recorded one attempt - tracker.windowStartTime = now return false } - // Check if just exceeded attempt limit - if int(tracker.attempts) >= rc.config.MaxRefreshAttempts { - // Enter cooldown now - tracker.inCooldown = true - tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod) - rc.logger.Infof("Session %s entering refresh cooldown after %d attempts", - sessionID, tracker.attempts) + // Window expired? + if windowStart := atomic.LoadInt64(&tracker.windowStartNano); time.Duration(nowNano-windowStart) > rc.config.RefreshAttemptWindow { + atomic.StoreInt32(&tracker.attempts, 1) + atomic.StoreInt64(&tracker.windowStartNano, nowNano) + return false + } + + // Just exceeded attempt limit? + if int(atomic.LoadInt32(&tracker.attempts)) >= rc.config.MaxRefreshAttempts { + end := now.Add(rc.config.RefreshCooldownPeriod).UnixNano() + // Only one CAS winner publishes the cooldown end + logs. + if atomic.CompareAndSwapInt64(&tracker.cooldownEndNano, 0, end) { + rc.logger.Infof("Session %s entering refresh cooldown after %d attempts", + sessionID, atomic.LoadInt32(&tracker.attempts)) + } return true } return false } -// recordRefreshAttempt records a refresh attempt for rate limiting +// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free: +// LoadOrStore for the tracker, atomic counters/timestamps for fields. 
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) { - rc.attemptsMutex.Lock() - defer rc.attemptsMutex.Unlock() - - tracker, exists := rc.sessionRefreshAttempts[sessionID] - if !exists { - tracker = &refreshAttemptTracker{ - windowStartTime: time.Now(), - } - rc.sessionRefreshAttempts[sessionID] = tracker - } - + tracker := rc.getOrCreateTracker(sessionID) atomic.AddInt32(&tracker.attempts, 1) - tracker.lastAttemptTime = time.Now() + atomic.StoreInt64(&tracker.lastAttemptNano, time.Now().UnixNano()) } -// recordRefreshSuccess records a successful refresh +// recordRefreshSuccess records a successful refresh. Lock-free. func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) { - rc.attemptsMutex.Lock() - defer rc.attemptsMutex.Unlock() - - if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists { - tracker.consecutiveFailures = 0 + if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok { + atomic.StoreInt32(&trackerFromMapValue(v).consecutiveFailures, 0) } } -// recordRefreshFailure records a failed refresh +// recordRefreshFailure records a failed refresh. Lock-free. func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) { - rc.attemptsMutex.Lock() - defer rc.attemptsMutex.Unlock() - - if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists { - atomic.AddInt32(&tracker.consecutiveFailures, 1) + if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok { + atomic.AddInt32(&trackerFromMapValue(v).consecutiveFailures, 1) } } @@ -546,20 +578,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() { } } -// cleanupStaleEntries removes outdated tracking entries +// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration +// via sync.Map.Range; safe to race with concurrent reads/writes. 
func (rc *RefreshCoordinator) cleanupStaleEntries() { - now := time.Now() - - rc.attemptsMutex.Lock() - defer rc.attemptsMutex.Unlock() - - // Clean up old session trackers - for sessionID, tracker := range rc.sessionRefreshAttempts { - // Remove trackers that haven't been used recently - if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow { - delete(rc.sessionRefreshAttempts, sessionID) + cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano() + rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool { + tracker := trackerFromMapValue(value) + if tracker == nil { + return true } - } + if atomic.LoadInt64(&tracker.lastAttemptNano) < cutoff { + // CompareAndDelete only spares a tracker that was replaced under this + // key; one re-used in place keeps the same pointer and can still be evicted. + rc.sessionRefreshAttempts.CompareAndDelete(key, value) + } + return true + }) } // GetMetrics returns current coordinator metrics @@ -592,63 +626,51 @@ func (rc *RefreshCoordinator) Shutdown() { rc.wg.Wait() } -// AllowRequest checks if the circuit breaker allows a request +// AllowRequest reports whether the circuit breaker allows a request. Lock-free. func (cb *RefreshCircuitBreaker) AllowRequest() bool { - cb.mutex.RLock() - defer cb.mutex.RUnlock() - - state := atomic.LoadInt32(&cb.state) - - switch state { - case 0: // Closed + switch atomic.LoadInt32(&cb.state) { + case 0: // closed return true - case 1: // Open - if time.Since(cb.lastFailureTime) > cb.config.OpenDuration { - // Try to transition to half-open + case 1: // open + lastFail := atomic.LoadInt64(&cb.lastFailureNano) + if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration { + // Transition to half-open; first CAS winner gets the probe.
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) { return true } } return false - case 2: // Half-open + case 2: // half-open return true default: return false } } -// RecordSuccess records a successful operation +// RecordSuccess records a successful operation. Lock-free. func (cb *RefreshCircuitBreaker) RecordSuccess() { - cb.mutex.Lock() - defer cb.mutex.Unlock() - - state := atomic.LoadInt32(&cb.state) - if state == 2 { // Half-open - // Close the circuit + switch atomic.LoadInt32(&cb.state) { + case 2: // half-open -> close atomic.StoreInt32(&cb.state, 0) atomic.StoreInt32(&cb.failures, 0) - } else if state == 0 { // Closed - // Reset failure count on success + case 0: // closed atomic.StoreInt32(&cb.failures, 0) } - cb.lastSuccessTime = time.Now() + atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano()) } -// RecordFailure records a failed operation +// RecordFailure records a failed operation. Lock-free. func (cb *RefreshCircuitBreaker) RecordFailure() { - cb.mutex.Lock() - defer cb.mutex.Unlock() - failures := atomic.AddInt32(&cb.failures, 1) - cb.lastFailureTime = time.Now() + atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano()) - state := atomic.LoadInt32(&cb.state) - - if state == 0 && int(failures) >= cb.config.MaxFailures { - // Open the circuit - atomic.StoreInt32(&cb.state, 1) - } else if state == 2 { - // Half-open failed, return to open + switch atomic.LoadInt32(&cb.state) { + case 0: + if int(failures) >= cb.config.MaxFailures { + atomic.StoreInt32(&cb.state, 1) + } + case 2: + // Half-open probe failed -> back to open. 
atomic.StoreInt32(&cb.state, 1) } } diff --git a/refresh_coordinator_test.go b/refresh_coordinator_test.go index ea26baf..4e9e45f 100644 --- a/refresh_coordinator_test.go +++ b/refresh_coordinator_test.go @@ -365,10 +365,12 @@ func TestMemoryLeakPrevention(t *testing.T) { } } - // Verify cleanup is working - coordinator.attemptsMutex.RLock() - sessionCount := len(coordinator.sessionRefreshAttempts) - coordinator.attemptsMutex.RUnlock() + // Verify cleanup is working. sync.Map has no Len(); count via Range. + sessionCount := 0 + coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool { + sessionCount++ + return true + }) // Should have cleaned up old sessions (only recent ones remain) if sessionCount > numWorkers*2 { @@ -650,24 +652,23 @@ func TestCleanupRoutine(t *testing.T) { coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i)) } - // Verify sessions exist - coordinator.attemptsMutex.RLock() - initialCount := len(coordinator.sessionRefreshAttempts) - coordinator.attemptsMutex.RUnlock() + countSessions := func() int { + n := 0 + coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool { + n++ + return true + }) + return n + } - if initialCount != 5 { + if initialCount := countSessions(); initialCount != 5 { t.Errorf("Expected 5 sessions, got %d", initialCount) } // Wait for cleanup to run (2x window + cleanup interval) time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval) - // Verify sessions were cleaned up - coordinator.attemptsMutex.RLock() - finalCount := len(coordinator.sessionRefreshAttempts) - coordinator.attemptsMutex.RUnlock() - - if finalCount != 0 { + if finalCount := countSessions(); finalCount != 0 { t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount) } } diff --git a/types.go b/types.go index 06323df..2381d2f 100644 --- a/types.go +++ b/types.go @@ -65,7 +65,19 @@ type ProviderMetadata struct { // the complete authentication flow. 
It's designed to work seamlessly with Traefik's // plugin system and provides flexible configuration options. type TraefikOidc struct { - lastMetadataRetryTime time.Time + // lastMetadataRetryNano is the UnixNano timestamp of the last metadata + // recovery attempt. Stored atomically so the hot ServeHTTP path can + // throttle retries without acquiring metadataRetryMutex on every request. + lastMetadataRetryNano int64 + // firstRequestStarted is 0 until the very first non-health request fires + // the background-task bootstrap; then it flips to 1 via CAS. Replaces the + // firstRequestMutex + firstRequestReceived combo which previously took + // a write lock on every non-health request forever. + firstRequestStarted int32 + // metadataRefreshStartedAtomic is the CAS-only variant of the old + // metadataRefreshStarted bool. It is a separate flag from firstRequestStarted; + // ServeHTTP flips it with its own CAS before calling startMetadataRefresh. + metadataRefreshStartedAtomic int32 jwkCache JWKCacheInterface jwtVerifier JWTVerifier ctx context.Context @@ -130,17 +142,13 @@ type TraefikOidc struct { maxRefreshTokenAge time.Duration metadataMu sync.RWMutex shutdownOnce sync.Once - metadataRetryMutex sync.Mutex - firstRequestMutex sync.Mutex sessionInvalidationCache CacheInterface refreshResultCache CacheInterface minimalHeaders bool stripAuthCookies bool enableBackchannelLogout bool enableFrontchannelLogout bool - firstRequestReceived bool requireTokenIntrospection bool - metadataRefreshStarted bool allowPrivateIPAddresses bool disableReplayDetection bool allowOpaqueTokens bool