diff --git a/auth_flow.go b/auth_flow.go index 9cf672d..47b3113 100644 --- a/auth_flow.go +++ b/auth_flow.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "strings" + "time" ) // validateRedirectCount checks if redirect limit is exceeded and handles the error @@ -360,9 +361,31 @@ func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool { return !strings.Contains(accept, "text/html") } -// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours) +// isRefreshTokenExpired checks whether the stored refresh token is likely +// past its useful lifetime, using the cookie-side issued_at timestamp set by +// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a +// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via +// MaxRefreshTokenAgeSeconds; 0 disables the check). +// +// The point of this check is to short-circuit the refresh path BEFORE the +// thundering herd hits the IdP for a token the provider has almost certainly +// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana- +// style polling clients from looping on invalid_grant after a long pause. func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool { - // This is a heuristic check - actual implementation would depend on - // the specific provider and token metadata - return false // Placeholder implementation + if t == nil || session == nil { + return false + } + if t.maxRefreshTokenAge <= 0 { + return false + } + + issuedAt := session.GetRefreshTokenIssuedAt() + if issuedAt.IsZero() { + // No timestamp recorded (legacy session pre-dating the issued_at + // field). Don't force a re-auth - attempt refresh once and let the + // IdP be the source of truth. + return false + } + + return time.Since(issuedAt) > t.maxRefreshTokenAge } diff --git a/cache_manager.go b/cache_manager.go index 399f3f2..0508cbd 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -113,6 +113,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface { return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true} } +// GetSharedRefreshResultCache returns the short-lived refresh-result cache used +// by the refresh path to coalesce grants across Traefik replicas via Redis. +func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface { + cm.mu.RLock() + defer cm.mu.RUnlock() + return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true} +} + // Close gracefully shuts down all cache components func (cm *CacheManager) Close() error { cm.mu.Lock() diff --git a/internal/cache/backends/hybrid.go b/internal/cache/backends/hybrid.go index b8cf0e6..890c513 100644 --- a/internal/cache/backends/hybrid.go +++ b/internal/cache/backends/hybrid.go @@ -20,6 +20,7 @@ type HybridBackend struct { ctx context.Context syncWriteCacheTypes map[string]bool asyncWriteBuffer chan *asyncWriteItem + l1BackfillBuffer chan *l1BackfillItem cancel context.CancelFunc wg sync.WaitGroup l1Hits atomic.Int64 @@ -28,6 +29,7 @@ type HybridBackend struct { l1Writes atomic.Int64 misses atomic.Int64 l2Hits atomic.Int64 + l1BackfillDrops atomic.Int64 fallbackMode atomic.Bool } @@ -39,6 +41,15 @@ type asyncWriteItem struct { ttl time.Duration } +// l1BackfillItem represents a deferred write of an L2-resolved value back into +// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot +// detonate the goroutine count (issue: ~1000% CPU under sustained polling). +type l1BackfillItem struct { + key string + value []byte + ttl time.Duration +} + // Logger interface for structured logging type Logger interface { Debugf(format string, args ...interface{}) @@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) { secondary: config.Secondary, syncWriteCacheTypes: config.SyncWriteCacheTypes, asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize), + l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize), ctx: ctx, cancel: cancel, logger: config.Logger, @@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) { h.wg.Add(1) go h.asyncWriteWorker() + // Start L1 backfill worker (single goroutine) to bound goroutine growth on + // L2 hits regardless of request rate. + h.wg.Add(1) + go h.l1BackfillWorker() + // Start health monitoring h.wg.Add(1) go h.healthMonitor() @@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat h.l2Hits.Add(1) - // Populate L1 cache with value from L2 (write-through on read) - // Use goroutine to avoid blocking the read path - go func() { - writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - if err := h.primary.Set(writeCtx, key, value, ttl); err != nil { - h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err) - } else { - h.logger.Debugf("Populated L1 cache from L2 for key: %s", key) - } - }() + // Populate L1 cache with value from L2 (write-through on read). + // Hand off to the bounded backfill worker instead of spawning a goroutine + // per read - under burst that would mint thousands of goroutines. + h.queueL1Backfill(key, value, ttl) return value, ttl, true, nil } @@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error { // Close async write channel close(h.asyncWriteBuffer) + close(h.l1BackfillBuffer) // Wait for workers to finish with timeout done := make(chan struct{}) @@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string] for key, value := range l2Results { results[key] = value h.l2Hits.Add(1) - - // Asynchronously populate L1 - go func(k string, v []byte) { - writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - _ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL - }(key, value) + h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL } } } else { @@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string] if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists { results[key] = value h.l2Hits.Add(1) - - // Asynchronously populate L1 - go func(k string, v []byte, t time.Duration) { - writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - _ = h.primary.Set(writeCtx, k, v, t) - }(key, value, ttl) + h.queueL1Backfill(key, value, ttl) } } } @@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt return nil } +// queueL1Backfill enqueues an L2-resolved value for write-through into L1. +// Drops on full buffer to keep the read path constant-time; the next L2 hit +// for the same key simply re-queues it. +func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) { + select { + case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}: + default: + h.l1BackfillDrops.Add(1) + h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key) + } +} + +// l1BackfillWorker drains the backfill queue serially. Single worker is +// intentional - L1 writes are local and cheap, and serializing them keeps +// goroutine count bounded under any read rate. +func (h *HybridBackend) l1BackfillWorker() { + defer h.wg.Done() + + for { + select { + case <-h.ctx.Done(): + // Drain remaining items best-effort then exit. + for len(h.l1BackfillBuffer) > 0 { + select { + case item := <-h.l1BackfillBuffer: + writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + _ = h.primary.Set(writeCtx, item.key, item.value, item.ttl) + cancel() + default: + return + } + } + return + + case item, ok := <-h.l1BackfillBuffer: + if !ok { + return + } + writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil { + h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err) + } else { + h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key) + } + cancel() + } + } +} + // asyncWriteWorker processes asynchronous writes to L2 func (h *HybridBackend) asyncWriteWorker() { defer h.wg.Done() diff --git a/internal/cache/backends/hybrid_l1_backfill_test.go b/internal/cache/backends/hybrid_l1_backfill_test.go new file mode 100644 index 0000000..fb33cb6 --- /dev/null +++ b/internal/cache/backends/hybrid_l1_backfill_test.go @@ -0,0 +1,112 @@ +//go:build !yaegi + +package backends + +import ( + "context" + "fmt" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does +// not detonate the goroutine count. Pre-fix the code spawned one goroutine +// per Get() L2 hit; post-fix all backfills funnel through a single worker. +func TestHybridBackend_L1BackfillBounded(t *testing.T) { + primary := newMockBackend() + secondary := newMockBackend() + + hybrid, err := NewHybridBackend(&HybridConfig{ + Primary: primary, + Secondary: secondary, + AsyncBufferSize: 256, + }) + require.NoError(t, err) + defer hybrid.Close() + + ctx := context.Background() + const burst = 1000 + + // Pre-populate L2 with `burst` distinct keys so each Get triggers a + // fresh L1 backfill enqueue. + for i := 0; i < burst; i++ { + require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)) + } + + baseline := runtime.NumGoroutine() + + // Issue the burst as fast as possible; the backfill worker MUST be the + // only goroutine doing L1 writes. Allow brief slack for the test runtime + // scheduling but anything north of +20 means goroutine leakage. + peak := baseline + for i := 0; i < burst; i++ { + _, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i)) + require.NoError(t, err) + require.True(t, exists) + if g := runtime.NumGoroutine(); g > peak { + peak = g + } + } + + delta := peak - baseline + if delta > 20 { + t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines", + delta, baseline, peak) + } + + // L1 must eventually catch up via the worker. Worker drains serially so + // give it a generous window proportional to the burst size. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + var populated int + for i := 0; i < burst; i++ { + if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok { + populated++ + } + } + // Be lenient: drops are acceptable under buffer pressure, just want + // most of the keys to make it. + if populated >= burst-int(hybrid.l1BackfillDrops.Load()) { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d", + hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load()) +} + +// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the +// buffer is saturated. Drops must be counted, never block, never spawn a +// goroutine. +func TestHybridBackend_L1BackfillFullDrops(t *testing.T) { + primary := newMockBackend() + secondary := newMockBackend() + + // Tiny buffer + slow primary writes via failSet so the worker stays + // blocked enough to overflow the buffer. + hybrid, err := NewHybridBackend(&HybridConfig{ + Primary: primary, + Secondary: secondary, + AsyncBufferSize: 4, + }) + require.NoError(t, err) + defer hybrid.Close() + + // Stop the worker from draining: cancel the underlying context so the + // worker bails out, leaving us with a cold buffer and the queue method + // itself responsible for drop accounting. + hybrid.cancel() + // Wait for worker to exit so it can't drain. + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 50; i++ { + hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute) + } + + assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0), + "expected some drops when buffer is saturated and worker is stopped") +} diff --git a/main.go b/main.go index 9086980..38747e8 100644 --- a/main.go +++ b/main.go @@ -226,6 +226,13 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name } return 60 * time.Second }(), + maxRefreshTokenAge: func() time.Duration { + // 0 (or unset) disables the heuristic; negative is rejected by Validate. + if config.MaxRefreshTokenAgeSeconds > 0 { + return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second + } + return 0 + }(), tokenCleanupStopChan: make(chan struct{}), metadataRefreshStopChan: make(chan struct{}), ctx: pluginCtx, @@ -242,6 +249,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL), frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL), sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(), + refreshResultCache: cacheManager.GetSharedRefreshResultCache(), } // Log audience configuration @@ -260,6 +268,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name tokenResilienceConfig := DefaultTokenResilienceConfig() t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger) + // Coalesces concurrent refresh-token grants per refresh_token to one upstream + // call, preventing the thundering herd that yields invalid_grant when the IdP + // rotates refresh tokens (Zitadel/Authentik default). + t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger) + t.extractClaimsFunc = extractClaims t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) { t.defaultInitiateAuthentication(rw, req, session, redirectURL) diff --git a/refresh_coordinator.go b/refresh_coordinator.go index b741e26..ee81c5b 100644 --- a/refresh_coordinator.go +++ b/refresh_coordinator.go @@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) { // hashRefreshToken creates a hash of the refresh token for deduplication func (rc *RefreshCoordinator) hashRefreshToken(token string) string { + return refreshCoordinatorSessionID(token) +} + +// refreshCoordinatorSessionID derives a stable identifier from a refresh token +// for both deduplication and per-session attempt tracking. Using sha256 of the +// raw token means each rotation produces a fresh sessionID with its own attempt +// budget, which is what we want. +func refreshCoordinatorSessionID(token string) string { hash := sha256.Sum256([]byte(token)) return hex.EncodeToString(hash[:]) } +// refreshCoordinatorWaitTimeout caps how long a request may wait for a +// coordinated refresh result. It is wider than RefreshTimeout so a follower +// always sees the leader's result instead of timing out independently. +const refreshCoordinatorWaitTimeout = 35 * time.Second + // isUnderMemoryPressure checks if the system is under memory pressure by // consulting the global memory monitor. Returns true when pressure reaches // High or Critical, at which point we refuse new refresh operations to diff --git a/refresh_coordinator_wireup_test.go b/refresh_coordinator_wireup_test.go new file mode 100644 index 0000000..5f658a9 --- /dev/null +++ b/refresh_coordinator_wireup_test.go @@ -0,0 +1,164 @@ +package traefikoidc + +import ( + "context" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// stubTokenExchanger lets us count how many upstream refresh-token grants +// happen for a given refresh_token across concurrent middleware-level calls. +type stubTokenExchanger struct { + calls int32 + delay time.Duration + resp *TokenResponse +} + +func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) { + return nil, nil +} + +func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) { + atomic.AddInt32(&s.calls, 1) + if s.delay > 0 { + time.Sleep(s.delay) + } + return s.resp, nil +} + +func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error { + return nil +} + +// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many +// concurrent calls to coordinatedTokenRefresh with the same refresh token +// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call. +// +// Without the wireup this assertion fails (one upstream call per goroutine). +func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) { + stub := &stubTokenExchanger{ + delay: 100 * time.Millisecond, + resp: &TokenResponse{ + AccessToken: "new_access", + RefreshToken: "new_refresh", + IDToken: "new_id", + ExpiresIn: 3600, + }, + } + + logger := NewLogger("error") + cfg := DefaultRefreshCoordinatorConfig() + cfg.MaxRefreshAttempts = 10000 + cfg.MaxConcurrentRefreshes = 32 + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(cfg, logger), + } + defer oidc.refreshCoordinator.Shutdown() + + const concurrency = 50 + var wg sync.WaitGroup + wg.Add(concurrency) + + req := httptest.NewRequest("GET", "/", nil) + start := make(chan struct{}) + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + <-start + resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token") + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + if resp == nil || resp.AccessToken != "new_access" { + t.Errorf("unexpected response: %+v", resp) + } + }() + } + + close(start) + wg.Wait() + + got := atomic.LoadInt32(&stub.calls) + // Up to 2 is acceptable to absorb the documented timing slack in the + // existing coordinator tests (e.g. operation just cleaned up before a + // late goroutine reads the in-flight map). Anything beyond that means + // coalescing is broken. + if got > 2 { + t.Fatalf("expected <=2 upstream refresh calls, got %d", got) + } +} + +// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil +// coordinator path so existing tests that build TraefikOidc literals stay +// green. +func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) { + stub := &stubTokenExchanger{ + resp: &TokenResponse{AccessToken: "ok"}, + } + + oidc := &TraefikOidc{ + logger: NewLogger("error"), + tokenExchanger: stub, + // refreshCoordinator deliberately nil + } + + resp, err := oidc.coordinatedTokenRefresh(nil, "rt") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.AccessToken != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if got := atomic.LoadInt32(&stub.calls); got != 1 { + t.Fatalf("expected exactly 1 upstream call, got %d", got) + } +} + +// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that +// distinct refresh tokens are not falsely coalesced. +func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) { + stub := &stubTokenExchanger{ + delay: 20 * time.Millisecond, + resp: &TokenResponse{AccessToken: "ok"}, + } + + logger := NewLogger("error") + cfg := DefaultRefreshCoordinatorConfig() + cfg.MaxRefreshAttempts = 10000 + cfg.MaxConcurrentRefreshes = 32 + cfg.DeduplicationCleanupDelay = 0 + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(cfg, logger), + } + defer oidc.refreshCoordinator.Shutdown() + + const distinct = 8 + var wg sync.WaitGroup + wg.Add(distinct) + for i := 0; i < distinct; i++ { + i := i + go func() { + defer wg.Done() + _, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i)))) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }() + } + wg.Wait() + + if got := atomic.LoadInt32(&stub.calls); int(got) != distinct { + t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got) + } +} diff --git a/refresh_distributed_test.go b/refresh_distributed_test.go new file mode 100644 index 0000000..8aa7e97 --- /dev/null +++ b/refresh_distributed_test.go @@ -0,0 +1,186 @@ +package traefikoidc + +import ( + "context" + "errors" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +// inMemoryCache is the smallest CacheInterface that satisfies the cross- +// replica dedup contract: Set/Get with TTL. Used in place of the universal +// cache singleton so these tests stay hermetic. +type inMemoryCache struct { + entries map[string]inMemoryCacheEntry + mu sync.Mutex +} + +type inMemoryCacheEntry struct { + expiresAt time.Time + value interface{} +} + +func newInMemoryCache() *inMemoryCache { + return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)} +} + +func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)} +} + +func (c *inMemoryCache) Get(key string) (any, bool) { + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.entries[key] + if !ok { + return nil, false + } + if time.Now().After(e.expiresAt) { + delete(c.entries, key) + return nil, false + } + return e.value, true +} + +func (c *inMemoryCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.entries, key) +} + +func (c *inMemoryCache) SetMaxSize(int) {} +func (c *inMemoryCache) Cleanup() {} +func (c *inMemoryCache) Close() {} +func (c *inMemoryCache) Size() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.entries) +} +func (c *inMemoryCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.entries = map[string]inMemoryCacheEntry{} +} +func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} } + +// erroringTokenExchanger always errors - simulates an IdP rejection. +type erroringTokenExchanger struct { + calls int32 +} + +func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) { + return nil, errors.New("not used") +} + +func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) { + atomic.AddInt32(&e.calls, 1) + return nil, errors.New("invalid_grant") +} + +func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil } + +// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik +// replica having just refreshed: the shared cache already has the result, so +// this pod must reuse it without ever calling the IdP. +func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) { + stub := &stubTokenExchanger{ + resp: &TokenResponse{AccessToken: "should_not_be_called"}, + } + + logger := NewLogger("error") + cache := newInMemoryCache() + preExisting := &TokenResponse{ + AccessToken: "from_peer", + RefreshToken: "rotated_by_peer", + IDToken: "id_from_peer", + } + rt := "shared_refresh_token" + cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL) + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger), + refreshResultCache: cache, + } + defer oidc.refreshCoordinator.Shutdown() + + resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.AccessToken != "from_peer" { + t.Fatalf("expected peer-provided response, got %+v", resp) + } + if got := atomic.LoadInt32(&stub.calls); got != 0 { + t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got) + } +} + +// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a +// cache miss the leader stores its result for peers to find within the TTL. +func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) { + stub := &stubTokenExchanger{ + resp: &TokenResponse{AccessToken: "fresh_grant"}, + } + + logger := NewLogger("error") + cache := newInMemoryCache() + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: stub, + refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger), + refreshResultCache: cache, + } + defer oidc.refreshCoordinator.Shutdown() + + rt := "fresh_refresh_token" + resp, err := oidc.coordinatedTokenRefresh(nil, rt) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.AccessToken != "fresh_grant" { + t.Fatalf("unexpected response: %+v", resp) + } + if got := atomic.LoadInt32(&stub.calls); got != 1 { + t.Fatalf("expected 1 upstream call, got %d", got) + } + + v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt))) + if !ok { + t.Fatal("expected refresh result to be cached after upstream success") + } + if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" { + t.Fatalf("cached value malformed: %+v", v) + } +} + +// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the +// dedup cache when the IdP rejects the grant. Peers must run their own +// refresh; they cannot inherit an error. +func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) { + failing := &erroringTokenExchanger{} + logger := NewLogger("error") + cache := newInMemoryCache() + + oidc := &TraefikOidc{ + logger: logger, + tokenExchanger: failing, + refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger), + refreshResultCache: cache, + } + defer oidc.refreshCoordinator.Shutdown() + + if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil { + t.Fatal("expected an error from the failing exchanger") + } + if cache.Size() != 0 { + t.Fatalf("error result must not be cached, size=%d", cache.Size()) + } +} diff --git a/refresh_token_expiry_test.go b/refresh_token_expiry_test.go new file mode 100644 index 0000000..de45673 --- /dev/null +++ b/refresh_token_expiry_test.go @@ -0,0 +1,68 @@ +package traefikoidc + +import ( + "testing" + "time" + + "github.com/gorilla/sessions" +) + +// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt +// reads from. We can't reuse sessionPool.Get() here because that requires a +// fully initialized SessionManager - overkill for this unit-level check. +func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData { + t.Helper() + rs := sessions.NewSession(nil, "refresh") + if !issuedAt.IsZero() { + rs.Values["issued_at"] = issuedAt.Unix() + } + return &SessionData{ + refreshSession: rs, + accessTokenChunks: make(map[int]*sessions.Session), + refreshTokenChunks: make(map[int]*sessions.Session), + idTokenChunks: make(map[int]*sessions.Session), + } +} + +func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) { + tr := &TraefikOidc{maxRefreshTokenAge: 0} + sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour)) + if tr.isRefreshTokenExpired(sd) { + t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0") + } +} + +func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) { + tr := &TraefikOidc{maxRefreshTokenAge: time.Hour} + sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value + if tr.isRefreshTokenExpired(sd) { + t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)") + } +} + +func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) { + tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour} + sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour)) + if tr.isRefreshTokenExpired(sd) { + t.Fatal("expected isRefreshTokenExpired=false within max age") + } +} + +func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) { + tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour} + sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour)) + if !tr.isRefreshTokenExpired(sd) { + t.Fatal("expected isRefreshTokenExpired=true beyond max age") + } +} + +func TestIsRefreshTokenExpired_NilGuards(t *testing.T) { + var tr *TraefikOidc + if tr.isRefreshTokenExpired(nil) { + t.Fatal("nil receiver must not panic and must return false") + } + tr = &TraefikOidc{maxRefreshTokenAge: time.Hour} + if tr.isRefreshTokenExpired(nil) { + t.Fatal("nil session must return false") + } +} diff --git a/settings.go b/settings.go index 3aec242..1dacfe0 100644 --- a/settings.go +++ b/settings.go @@ -55,6 +55,15 @@ type Config struct { AllowedUsers []string `json:"allowedUsers"` Headers []TemplatedHeader `json:"headers"` RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"` + // MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of + // a stored refresh token. Once the token has been in the session longer + // than this, requests treat it as expired up-front - returning 401 to + // AJAX callers and triggering full re-auth on navigations - instead of + // hammering the IdP with grants that will only fail with invalid_grant. + // IdPs do not expose RT TTL on the wire, so this is intentionally a + // conservative heuristic; tune to match your provider configuration. + // Default 21600 (6h). Set to 0 to disable the check. + MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"` SessionMaxAge int `json:"sessionMaxAge"` RateLimit int `json:"rateLimit"` OverrideScopes bool `json:"overrideScopes"` @@ -247,6 +256,7 @@ func CreateConfig() *Config { EnablePKCE: false, // PKCE is opt-in OverrideScopes: false, // Default to appending scopes, not overriding RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds + MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc SecurityHeaders: createDefaultSecurityConfig(), Redis: nil, // Redis is disabled by default, configure via Traefik or env vars } @@ -370,6 +380,11 @@ func (c *Config) Validate() error { return fmt.Errorf("refreshGracePeriodSeconds cannot be negative") } + // Validate refresh-token max-age heuristic + if c.MaxRefreshTokenAgeSeconds < 0 { + return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative") + } + // Validate audience if specified if c.Audience != "" { // Validate audience format - should be a valid identifier or URL diff --git a/token_manager.go b/token_manager.go index 01e208b..f0e5e9f 100644 --- a/token_manager.go +++ b/token_manager.go @@ -46,6 +46,17 @@ func (t *TraefikOidc) VerifyToken(token string) error { } } + // Hot-path fast-return: a previously-verified token has already passed + // signature, claims, and replay checks. Skipping the parseJWT cost here + // matters under bursty traffic (e.g. 10+ concurrent panel requests on + // every Grafana dashboard refresh) where the same token is validated + // dozens of times per second by validateStandardTokens. + if t.tokenCache != nil { + if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 { + return nil + } + } + parsedJWT, parseErr := parseJWT(token) if parseErr != nil { return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr) @@ -63,12 +74,6 @@ func (t *TraefikOidc) VerifyToken(token string) error { } } - // Check token cache FIRST - if token is already verified and cached, return immediately - // This prevents false positives when multiple goroutines validate the same token concurrently - if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 { - return nil - } - // Only check JTI blacklist for tokens that aren't already in the cache // This is for FIRST-TIME validation to detect replay attacks if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" { @@ -416,7 +421,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se } t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix) - newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken) + newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") { @@ -518,6 +523,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se return true } +// coordinatedTokenRefresh routes a refresh-token grant through the +// RefreshCoordinator so that concurrent requests sharing the same refresh +// token coalesce into a single upstream call. This prevents the thundering +// herd that yields invalid_grant when the IdP rotates refresh tokens. +// +// Falls back to a direct call when the coordinator is nil, which only +// happens in tests that build TraefikOidc literals without going through +// NewWithContext. +func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) { + if t.refreshCoordinator == nil { + return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) + } + + parentCtx := context.Background() + if req != nil { + parentCtx = req.Context() + } + ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout) + defer cancel() + + sessionID := refreshCoordinatorSessionID(refreshToken) + + return t.refreshCoordinator.CoordinateRefresh( + ctx, + sessionID, + refreshToken, + func() (*TokenResponse, error) { + // Cross-replica dedup. The in-process coordinator already + // collapses concurrent grants on this pod; this Redis-backed + // short-TTL cache covers the (rare) case of a failover or + // load-balancer reroute mid-refresh, where two pods would + // otherwise both POST the same refresh_token to the IdP. + if cached, ok := t.lookupCachedRefreshResult(sessionID); ok { + return cached, nil + } + resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) + if err == nil && resp != nil { + t.cacheRefreshResult(sessionID, resp) + } + return resp, err + }, + ) +} + +// lookupCachedRefreshResult returns a previously-stored TokenResponse for the +// given refresh-token hash, if one exists and is still within its short TTL. +// The cache wraps the universal cache, which is Redis-backed in production - +// so a "hit" here means another Traefik replica refreshed this same token +// within the last few seconds. +func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) { + if t.refreshResultCache == nil { + return nil, false + } + v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID)) + if !ok || v == nil { + return nil, false + } + if tr, ok := v.(*TokenResponse); ok && tr != nil { + return tr, true + } + return nil, false +} + +// cacheRefreshResult stores the new TokenResponse under the refresh-token +// hash for a short window. TTL is intentionally tight: the rotated refresh +// token cannot be re-presented to the IdP, and any peer waiting longer than +// this window has almost certainly given up via its own coordinator timeout. +func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) { + if t.refreshResultCache == nil || resp == nil { + return + } + t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL) +} + +// refreshResultCacheKey namespaces refresh-result entries inside the shared +// cache namespace. +func refreshResultCacheKey(sessionID string) string { + return "rt-result:" + sessionID +} + +// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache. +// Long enough for a sibling replica to observe the result, short enough that +// a stale entry never re-supplies a token after the IdP has already moved on. +const refreshResultCacheTTL = 5 * time.Second + // RevokeToken revokes a token locally by adding it to the blacklist cache. // It removes the token from the verification cache and adds both the token // and its JTI (if present) to the blacklist to prevent future use. diff --git a/types.go b/types.go index bd6f980..64470d2 100644 --- a/types.go +++ b/types.go @@ -95,6 +95,7 @@ type TraefikOidc struct { cancelFunc context.CancelFunc errorRecoveryManager *ErrorRecoveryManager tokenResilienceManager *TokenResilienceManager + refreshCoordinator *RefreshCoordinator goroutineWG *sync.WaitGroup dcrConfig *DynamicClientRegistrationConfig dynamicClientRegistrar *DynamicClientRegistrar @@ -124,11 +125,13 @@ type TraefikOidc struct { scopesSupported []string scopes []string refreshGracePeriod time.Duration + maxRefreshTokenAge time.Duration metadataMu sync.RWMutex shutdownOnce sync.Once metadataRetryMutex sync.Mutex firstRequestMutex sync.Mutex sessionInvalidationCache CacheInterface + refreshResultCache CacheInterface minimalHeaders bool stripAuthCookies bool enableBackchannelLogout bool diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go index 3ccdde9..51aa644 100644 --- a/universal_cache_singleton.go +++ b/universal_cache_singleton.go @@ -23,6 +23,7 @@ type UniversalCacheManager struct { metadataCache *UniversalCache dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout + refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator) logger *Logger blacklistCache *UniversalCache cancel context.CancelFunc @@ -181,6 +182,18 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) { Logger: logger, SkipAutoCleanup: true, // Managed cleanup }) + + // Refresh-result cache: short-lived store keyed by sha256(refreshToken). + // In Redis-backed mode this gives cross-replica dedup of refresh grants; + // in memory-only mode it's effectively redundant with RefreshCoordinator + // but safe and cheap to keep. + manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{ + Type: CacheTypeToken, + MaxSize: 1000, + DefaultTTL: 5 * time.Second, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup + }) } // initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration @@ -387,6 +400,21 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r createBackend("session_invalidation"), ) + // Refresh-result cache - shared via Redis so concurrent refreshes across + // Traefik replicas can dedup their grants. The 5s TTL is long enough for + // peers to observe a recent refresh and short enough that a stale entry + // can't be replayed against a now-rotated refresh token. + manager.refreshResultCache = NewUniversalCacheWithBackend( + UniversalCacheConfig{ + Type: CacheTypeToken, + MaxSize: 1000, + DefaultTTL: 5 * time.Second, + Logger: logger, + SkipAutoCleanup: true, // Managed cleanup + }, + createBackend("refresh_result"), + ) + logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode) } @@ -436,6 +464,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() { m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, + m.refreshResultCache, } m.mu.RUnlock() @@ -498,6 +527,14 @@ func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache { return m.sessionInvalidationCache } +// GetRefreshResultCache returns the short-lived refresh-result cache used to +// coalesce refresh-token grants across Traefik replicas. +func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache { + m.mu.RLock() + defer m.mu.RUnlock() + return m.refreshResultCache +} + // GetDCRCredentialsCache returns the DCR credentials cache for distributed storage func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache { m.mu.RLock() @@ -520,7 +557,7 @@ func (m *UniversalCacheManager) Close() error { // Close all caches first (they won't close the shared backend) for _, cache := range []*UniversalCache{ - m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, + m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache, } { if cache != nil { _ = cache.Close() // Safe to ignore: best effort cache cleanup diff --git a/utilities.go b/utilities.go index 0df262b..c5d50e3 100644 --- a/utilities.go +++ b/utilities.go @@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error { t.safeLogDebug("metadataRefreshStopChan closed") } + if t.refreshCoordinator != nil { + t.refreshCoordinator.Shutdown() + t.safeLogDebug("refreshCoordinator shut down") + } + if t.goroutineWG != nil { done := make(chan struct{}) go func() { diff --git a/verify_token_hotpath_test.go b/verify_token_hotpath_test.go new file mode 100644 index 0000000..6dcf744 --- /dev/null +++ b/verify_token_hotpath_test.go @@ -0,0 +1,62 @@ +package traefikoidc + +import ( + "testing" + "time" + + "golang.org/x/time/rate" +) + +// TestVerifyToken_CacheHitSkipsParse proves the hot-path optimization: when a +// token is in the cache, VerifyToken returns nil without calling parseJWT. +// We construct a token that PASSES the cheap format checks (3 segments, len +// >= 10) but whose body is unparseable JSON. With the cache hit hoisted ahead +// of parseJWT, the function returns nil. Without the hoist, parseJWT would +// fail with "failed to parse JWT for blacklist check". +func TestVerifyToken_CacheHitSkipsParse(t *testing.T) { + tr := &TraefikOidc{ + logger: NewLogger("error"), + tokenCache: NewTokenCache(), + // limiter intentionally absent; if we reached the rate-limit check + // the test would NPE - this is a stronger assertion that we exit + // before that point. + limiter: rate.NewLimiter(rate.Inf, 1), + } + tr.tokenVerifier = tr + + // Three segments separated by '.', body is junk after base64-decode + JSON. + // Pre-fix this fails parseJWT; post-fix it returns nil because the cache + // short-circuits. + junkToken := "header.bm90LWpzb24.signature" // base64(not-json) in the middle + tr.tokenCache.Set(junkToken, map[string]interface{}{ + "exp": float64(time.Now().Add(time.Hour).Unix()), + "sub": "test", + }, time.Hour) + + if err := tr.VerifyToken(junkToken); err != nil { + t.Fatalf("expected cache-hit fast path to return nil, got: %v", err) + } +} + +// TestVerifyToken_CacheMissStillParses ensures we did not skip too aggressively +// - on a cache miss, the function must still parse and reach the rate-limit +// check. We assert by passing a syntactically valid token whose signature +// won't verify, expecting an error from later in the pipeline. +func TestVerifyToken_CacheMissStillParses(t *testing.T) { + tr := &TraefikOidc{ + logger: NewLogger("error"), + tokenCache: NewTokenCache(), + limiter: rate.NewLimiter(rate.Inf, 1), + // no tokenBlacklist, no jwkCache - the function will fail somewhere + // after parseJWT. We just need a non-nil error to confirm we did + // progress past the cache check. + } + tr.tokenVerifier = tr + + // Real JWT structure but unsigned/unverifiable. + rawToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature" + + if err := tr.VerifyToken(rawToken); err == nil { + t.Fatal("expected an error past parseJWT for an unsigned token, got nil") + } +}