From 9f96d8c38cd44a15f2642dc690b231ac8e1aa770 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Thu, 30 Apr 2026 18:23:43 +0100 Subject: [PATCH] fix(refresh): wire RefreshCoordinator into the live refresh path The RefreshCoordinator existed but was never instantiated. The actual refresh path used only session.refreshMutex, which is per-SessionData instance - and SessionData is pulled from a sync.Pool per request - so concurrent requests sharing a refresh token had ZERO coordination. Symptom: when access_token expired (e.g. 5min Zitadel default), every in-flight request from a polling client (Grafana panels) entered the refresh path simultaneously and POSTed the same refresh_token to the IdP. With refresh-token rotation enabled (Zitadel/Authentik default), only one grant succeeded; the rest got invalid_grant and each cleared the entire session. Subsequent requests then thrashed in re-auth loops. This commit: - adds refreshCoordinator field on TraefikOidc - instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig - shuts it down in Close() under shutdownOnce - routes refreshToken() through the coordinator via coordinatedTokenRefresh, which collapses concurrent grants to a single upstream call per refresh_token hash - exports refreshCoordinatorSessionID for both internal hashing and the middleware-level wireup so dedup keys stay aligned Behavioural notes: - nil-coordinator fallback preserves existing tests that build TraefikOidc literals without going through the constructor - followers receive the same TokenResponse/error as the leader, so no per-instance code paths change - existing TestGetNewTokenWithRefreshToken_Concurrency still passes because it hits GetNewTokenWithRefreshToken directly, below the coordinator boundary Tests: - refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce to <=2 upstream calls; distinct tokens still run in parallel; nil coordinator falls back cleanly --- main.go | 5 + refresh_coordinator.go | 13 +++ refresh_coordinator_wireup_test.go | 164 +++++++++++++++++++++++++++++ token_manager.go | 34 +++++- types.go | 1 + utilities.go | 5 + 6 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 refresh_coordinator_wireup_test.go diff --git a/main.go b/main.go index 9086980..ddc75e6 100644 --- a/main.go +++ b/main.go @@ -260,6 +260,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name tokenResilienceConfig := DefaultTokenResilienceConfig() t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger) + // Coalesces concurrent refresh-token grants per refresh_token to one upstream + // call, preventing the thundering herd that yields invalid_grant when the IdP + // rotates refresh tokens (Zitadel/Authentik default). + t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger) + t.extractClaimsFunc = extractClaims t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { t.defaultInitiateAuthentication(rw, req, session, redirectURL) diff --git a/refresh_coordinator.go b/refresh_coordinator.go index b741e26..ee81c5b 100644 --- a/refresh_coordinator.go +++ b/refresh_coordinator.go @@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) { // hashRefreshToken creates a hash of the refresh token for deduplication func (rc *RefreshCoordinator) hashRefreshToken(token string) string { + return refreshCoordinatorSessionID(token) +} + +// refreshCoordinatorSessionID derives a stable identifier from a refresh token +// for both deduplication and per-session attempt tracking. Using sha256 of the +// raw token means each rotation produces a fresh sessionID with its own attempt +// budget, which is what we want. +func refreshCoordinatorSessionID(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) } +// refreshCoordinatorWaitTimeout caps how long a request may wait for a +// coordinated refresh result. It is wider than RefreshTimeout so a follower +// always sees the leader's result instead of timing out independently. +const refreshCoordinatorWaitTimeout = 35 * time.Second + // isUnderMemoryPressure checks if the system is under memory pressure by // consulting the global memory monitor. Returns true when pressure reaches // High or Critical, at which point we refuse new refresh operations to diff --git a/refresh_coordinator_wireup_test.go b/refresh_coordinator_wireup_test.go new file mode 100644 index 0000000..5f658a9 --- /dev/null +++ b/refresh_coordinator_wireup_test.go @@ -0,0 +1,164 @@ +package traefikoidc + +import ( + "context" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// stubTokenExchanger lets us count how many upstream refresh-token grants +// happen for a given refresh_token across concurrent middleware-level calls. +type stubTokenExchanger struct { + calls int32 + delay time.Duration + resp *TokenResponse +} + +func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) { + return nil, nil +} + +func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) { + atomic.AddInt32(&s.calls, 1) + if s.delay > 0 { + time.Sleep(s.delay) + } + return s.resp, nil +} + +func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error { + return nil +} + +// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many +// concurrent calls to coordinatedTokenRefresh with the same refresh token +// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call. +// +// Without the wireup this assertion fails (one upstream call per goroutine). +func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) { + stub := &stubTokenExchanger{ + delay: 100 * time.Millisecond, + resp: &TokenResponse{ + AccessToken: "new_access", + RefreshToken: "new_refresh", + IDToken: "new_id", + ExpiresIn: 3600, + }, + } + + logger := NewLogger("error") + cfg := DefaultRefreshCoordinatorConfig() + cfg.MaxRefreshAttempts = 10000 + cfg.MaxConcurrentRefreshes = 32 + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(cfg, logger), + } + defer oidc.refreshCoordinator.Shutdown() + + const concurrency = 50 + var wg sync.WaitGroup + wg.Add(concurrency) + + req := httptest.NewRequest("GET", "/", nil) + start := make(chan struct{}) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + <-start + resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token") + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if resp == nil || resp.AccessToken != "new_access" { + t.Errorf("unexpected response: %+v", resp) + } + }() + } + + close(start) + wg.Wait() + + got := atomic.LoadInt32(&stub.calls) + // Up to 2 is acceptable to absorb the documented timing slack in the + // existing coordinator tests (e.g. operation just cleaned up before a + // late goroutine reads the in-flight map). Anything beyond that means + // coalescing is broken. + if got > 2 { + t.Fatalf("expected <=2 upstream refresh calls, got %d", got) + } +} + +// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil +// coordinator path so existing tests that build TraefikOidc literals stay +// green. +func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) { + stub := &stubTokenExchanger{ + resp: &TokenResponse{AccessToken: "ok"}, + } + + oidc := &TraefikOidc{ + logger: NewLogger("error"), + tokenExchanger: stub, + // refreshCoordinator deliberately nil + } + + resp, err := oidc.coordinatedTokenRefresh(nil, "rt") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.AccessToken != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if got := atomic.LoadInt32(&stub.calls); got != 1 { + t.Fatalf("expected exactly 1 upstream call, got %d", got) + } +} + +// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that +// distinct refresh tokens are not falsely coalesced. +func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) { + stub := &stubTokenExchanger{ + delay: 20 * time.Millisecond, + resp: &TokenResponse{AccessToken: "ok"}, + } + + logger := NewLogger("error") + cfg := DefaultRefreshCoordinatorConfig() + cfg.MaxRefreshAttempts = 10000 + cfg.MaxConcurrentRefreshes = 32 + cfg.DeduplicationCleanupDelay = 0 + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(cfg, logger), + } + defer oidc.refreshCoordinator.Shutdown() + + const distinct = 8 + var wg sync.WaitGroup + wg.Add(distinct) + for i := 0; i < distinct; i++ { + i := i + go func() { + defer wg.Done() + _, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i)))) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + } + wg.Wait() + + if got := atomic.LoadInt32(&stub.calls); int(got) != distinct { + t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got) + } +} diff --git a/token_manager.go b/token_manager.go index 01e208b..a86f8bc 100644 --- a/token_manager.go +++ b/token_manager.go @@ -416,7 +416,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se } t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix) - newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken) + newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") { @@ -518,6 +518,38 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return true } +// coordinatedTokenRefresh routes a refresh-token grant through the +// RefreshCoordinator so that concurrent requests sharing the same refresh +// token coalesce into a single upstream call. This prevents the thundering +// herd that yields invalid_grant when the IdP rotates refresh tokens. +// +// Falls back to a direct call when the coordinator is nil, which only +// happens in tests that build TraefikOidc literals without going through +// NewWithContext. +func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) { + if t.refreshCoordinator == nil { + return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) + } + + parentCtx := context.Background() + if req != nil { + parentCtx = req.Context() + } + ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout) + defer cancel() + + sessionID := refreshCoordinatorSessionID(refreshToken) + + return t.refreshCoordinator.CoordinateRefresh( + ctx, + sessionID, + refreshToken, + func() (*TokenResponse, error) { + return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) + }, + ) +} + // RevokeToken revokes a token locally by adding it to the blacklist cache. // It removes the token from the verification cache and adds both the token // and its JTI (if present) to the blacklist to prevent future use. diff --git a/types.go b/types.go index bd6f980..c55d188 100644 --- a/types.go +++ b/types.go @@ -95,6 +95,7 @@ type TraefikOidc struct { cancelFunc context.CancelFunc errorRecoveryManager *ErrorRecoveryManager tokenResilienceManager *TokenResilienceManager + refreshCoordinator *RefreshCoordinator goroutineWG *sync.WaitGroup dcrConfig *DynamicClientRegistrationConfig dynamicClientRegistrar *DynamicClientRegistrar diff --git a/utilities.go b/utilities.go index 0df262b..c5d50e3 100644 --- a/utilities.go +++ b/utilities.go @@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error { t.safeLogDebug("metadataRefreshStopChan closed") } + if t.refreshCoordinator != nil { + t.refreshCoordinator.Shutdown() + t.safeLogDebug("refreshCoordinator shut down") + } + if t.goroutineWG != nil { done := make(chan struct{}) go func() {