mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
fix(refresh): wire RefreshCoordinator into the live refresh path
The RefreshCoordinator existed but was never instantiated. The actual refresh path used only session.refreshMutex, which is per-SessionData instance - and SessionData is pulled from a sync.Pool per request - so concurrent requests sharing a refresh token had ZERO coordination.

Symptom: when the access_token expired (e.g. 5min Zitadel default), every in-flight request from a polling client (Grafana panels) entered the refresh path simultaneously and POSTed the same refresh_token to the IdP. With refresh-token rotation enabled (Zitadel/Authentik default), only one grant succeeded; the rest got invalid_grant and each cleared the entire session. Subsequent requests then thrashed in re-auth loops.

This commit:
- adds a refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh, which collapses concurrent grants to a single upstream call per refresh_token hash
- extracts refreshCoordinatorSessionID as a package-level helper used for both internal hashing and the middleware-level wireup, so dedup keys stay aligned

Behavioural notes:
- the nil-coordinator fallback preserves existing tests that build TraefikOidc literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no per-instance code paths change
- the existing TestGetNewTokenWithRefreshToken_Concurrency still passes because it hits GetNewTokenWithRefreshToken directly, below the coordinator boundary

Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce to <=2 upstream calls; distinct tokens still run in parallel; a nil coordinator falls back cleanly
This commit is contained in:
@@ -260,6 +260,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
@@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
// for both deduplication and per-session attempt tracking.
//
// The identifier is the lowercase hex encoding of the SHA-256 digest of the
// raw token bytes (always 64 characters). Because the digest depends only on
// the token, each rotation produces a fresh sessionID with its own attempt
// budget, which is what we want.
func refreshCoordinatorSessionID(token string) string {
	hash := sha256.Sum256([]byte(token))
	// Sum256 returns a fixed-size array; slice it for the hex encoder.
	return hex.EncodeToString(hash[:])
}
|
||||
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
// coordinated refresh result. It is meant to be wider than the coordinator's
// RefreshTimeout so that a follower sees the leader's result instead of
// timing out independently.
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
+33
-1
@@ -416,7 +416,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
}
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
@@ -518,6 +518,38 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// coordinatedTokenRefresh routes a refresh-token grant through the
|
||||
// RefreshCoordinator so that concurrent requests sharing the same refresh
|
||||
// token coalesce into a single upstream call. This prevents the thundering
|
||||
// herd that yields invalid_grant when the IdP rotates refresh tokens.
|
||||
//
|
||||
// Falls back to a direct call when the coordinator is nil, which only
|
||||
// happens in tests that build TraefikOidc literals without going through
|
||||
// NewWithContext.
|
||||
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
|
||||
if t.refreshCoordinator == nil {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
parentCtx := context.Background()
|
||||
if req != nil {
|
||||
parentCtx = req.Context()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
|
||||
defer cancel()
|
||||
|
||||
sessionID := refreshCoordinatorSessionID(refreshToken)
|
||||
|
||||
return t.refreshCoordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
func() (*TokenResponse, error) {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
||||
// It removes the token from the verification cache and adds both the token
|
||||
// and its JTI (if present) to the blacklist to prevent future use.
|
||||
|
||||
@@ -95,6 +95,7 @@ type TraefikOidc struct {
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
tokenResilienceManager *TokenResilienceManager
|
||||
refreshCoordinator *RefreshCoordinator
|
||||
goroutineWG *sync.WaitGroup
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
dynamicClientRegistrar *DynamicClientRegistrar
|
||||
|
||||
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
|
||||
t.safeLogDebug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
if t.refreshCoordinator != nil {
|
||||
t.refreshCoordinator.Shutdown()
|
||||
t.safeLogDebug("refreshCoordinator shut down")
|
||||
}
|
||||
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
Reference in New Issue
Block a user