mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
fix(refresh): coalesce refresh-token grants + bound goroutines + cache hot path (target v0.8.27) (#131)
* fix(refresh): wire RefreshCoordinator into the live refresh path
The RefreshCoordinator existed but was never instantiated. The actual
refresh path used only session.refreshMutex, which is per-SessionData
instance - and SessionData is pulled from a sync.Pool per request -
so concurrent requests sharing a refresh token had ZERO coordination.
Symptom: when access_token expired (e.g. 5min Zitadel default), every
in-flight request from a polling client (Grafana panels) entered the
refresh path simultaneously and POSTed the same refresh_token to the
IdP. With refresh-token rotation enabled (Zitadel/Authentik default),
only one grant succeeded; the rest got invalid_grant and each cleared
the entire session. Subsequent requests then thrashed in re-auth loops.
This commit:
- adds refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh,
which collapses concurrent grants to a single upstream call per
refresh_token hash
- exports refreshCoordinatorSessionID for both internal hashing and the
middleware-level wireup so dedup keys stay aligned
Behavioural notes:
- nil-coordinator fallback preserves existing tests that build TraefikOidc
literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no
per-instance code paths change
- existing TestGetNewTokenWithRefreshToken_Concurrency still passes
because it hits GetNewTokenWithRefreshToken directly, below the
coordinator boundary
Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce
to <=2 upstream calls; distinct tokens still run in parallel; nil
coordinator falls back cleanly
* perf(cache): bound L1 backfill goroutines in HybridBackend
Get() and GetMany() previously spawned a goroutine per L2 hit to write
the value through to L1. Under sustained polling traffic (e.g. a Grafana
dashboard refreshing every 30s with N panels) this minted thousands of
goroutines, each running in Yaegi - directly contributing to the
~1000% CPU spike that pairs with the refresh-token herd.
Replace the per-hit goroutines with a single l1BackfillWorker fed by
l1BackfillBuffer, mirroring the existing asyncWriteBuffer/asyncWriteWorker
pattern for L2 writes. Buffer overflow drops the backfill (counted via
l1BackfillDrops) - a dropped backfill just means the next L2 hit for
that key re-queues it, which is safe.
Tests:
- TestHybridBackend_L1BackfillBounded: 1000 distinct L2 hits keep
goroutine count within +20 of baseline (pre-fix it grew by ~1000)
- TestHybridBackend_L1BackfillFullDrops: drops are accounted for when
the buffer is saturated and the worker is stopped
* feat(refresh): implement isRefreshTokenExpired heuristic
Replace the placeholder `return false` with a real check based on the
issued_at timestamp that SetRefreshToken already stamps into the session.
Gated by a new MaxRefreshTokenAgeSeconds config field (default 21600 =
6h, matching the existing comment). 0 disables the check.
This wires the previously-dead refreshTokenExpired branch in middleware.go,
which short-circuits AJAX requests with a 401 instead of letting them
hammer the IdP for a refresh token that's almost certainly stale - the
classic Grafana-after-long-pause failure mode.
Behaviour:
- maxRefreshTokenAge=0 disables the check (preserves prior behaviour)
- legacy sessions without issued_at still attempt one refresh; the IdP
remains the source of truth on first try
- nil-receiver and nil-session guards keep test code that builds
TraefikOidc literals safe
Tests:
- TestIsRefreshTokenExpired_DisabledWhenAgeZero
- TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp
- TestIsRefreshTokenExpired_WithinWindow
- TestIsRefreshTokenExpired_BeyondWindow
- TestIsRefreshTokenExpired_NilGuards
* perf(token): skip parseJWT on cache hit in VerifyToken
The token cache fast-return existed but ran AFTER parseJWT, so every
validation paid for base64 + JSON unmarshal even on a hit. Under bursty
traffic (e.g. 10+ concurrent panel requests on every Grafana dashboard
refresh, each calling validateStandardTokens which verifies BOTH the
access token and the ID token), this is two redundant parses per
request multiplied by the panel count.
Move the cache lookup ahead of parseJWT. On a hit the function returns
nil immediately. On a miss the original flow runs unchanged.
Also nil-guard t.tokenCache to keep partial-literal test instances safe
(matches the same pattern we already use for tokenBlacklist).
Tests:
- TestVerifyToken_CacheHitSkipsParse: cache pre-populated with claims
for a token whose body would fail parseJWT - returns nil iff the
fast-path bypasses the parse
- TestVerifyToken_CacheMissStillParses: a syntactically valid but
unsigned token still errors past parseJWT on cache miss
* feat(refresh): cross-replica refresh-grant dedup via shared cache
The in-process RefreshCoordinator added in 9f96d8c already collapses
concurrent refresh-token grants on a single Traefik replica. With the
plugin's existing Redis (Dragonfly) cache infrastructure available, we
can extend that dedup across replicas: if pod A refreshes a token at
T+0 and pod B receives a request for the same session at T+1, pod B
should reuse pod A's result rather than POSTing the now-rotated refresh
token to the IdP.
Implementation:
- Add a refreshResultCache to UniversalCacheManager (memory-only when
Redis is disabled, Redis-backed in production via the existing
hybrid/Redis-only mode selection)
- Expose it through CacheManager.GetSharedRefreshResultCache and on the
TraefikOidc struct as refreshResultCache (CacheInterface)
- Inside the closure passed to RefreshCoordinator.CoordinateRefresh,
consult the cache first; on hit return immediately, on miss exchange
with the IdP and populate the cache for peers
- 5s TTL: long enough for siblings to observe, short enough that a
rotated refresh token cannot be re-supplied after the IdP has moved on
- Errors are intentionally NOT cached - peers must always be able to
retry on their own
Pragmatic choice: optimistic cache rather than a hard distributed lock.
- A hard lock (SET NX + poll) doubles Redis RTT and risks dead-locks
if a Traefik pod dies mid-grant.
- The user's BGP+Local externalTrafficPolicy already pins ingress for
a session to one node in steady state, so cross-pod racing is rare.
- This optimistic path catches the rare failover case without adding
failure modes.
Tests:
- TestCoordinatedTokenRefresh_CrossReplicaCacheHit: pre-populated cache
short-circuits the upstream call entirely (0 IdP calls)
- TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache: leader stores
a successful result for peers to find
- TestCoordinatedTokenRefresh_ErrorIsNotCached: invalid_grant must not
poison the dedup cache - peers must retry independently
This commit is contained in:
+27
-4
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
@@ -360,9 +361,31 @@ func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
|
||||
return !strings.Contains(accept, "text/html")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
// isRefreshTokenExpired checks whether the stored refresh token is likely
|
||||
// past its useful lifetime, using the cookie-side issued_at timestamp set by
|
||||
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
|
||||
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
|
||||
// MaxRefreshTokenAgeSeconds; 0 disables the check).
|
||||
//
|
||||
// The point of this check is to short-circuit the refresh path BEFORE the
|
||||
// thundering herd hits the IdP for a token the provider has almost certainly
|
||||
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
|
||||
// style polling clients from looping on invalid_grant after a long pause.
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
if t == nil || session == nil {
|
||||
return false
|
||||
}
|
||||
if t.maxRefreshTokenAge <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
issuedAt := session.GetRefreshTokenIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
// No timestamp recorded (legacy session pre-dating the issued_at
|
||||
// field). Don't force a re-auth - attempt refresh once and let the
|
||||
// IdP be the source of truth.
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(issuedAt) > t.maxRefreshTokenAge
|
||||
}
|
||||
|
||||
@@ -113,6 +113,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
|
||||
// by the refresh path to coalesce grants across Traefik replicas via Redis.
|
||||
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
|
||||
Vendored
+73
-26
@@ -20,6 +20,7 @@ type HybridBackend struct {
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
l1BackfillBuffer chan *l1BackfillItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
@@ -28,6 +29,7 @@ type HybridBackend struct {
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
l1BackfillDrops atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// l1BackfillItem represents a deferred write of an L2-resolved value back into
|
||||
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
|
||||
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
|
||||
type l1BackfillItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
|
||||
// L2 hits regardless of request rate.
|
||||
h.wg.Add(1)
|
||||
go h.l1BackfillWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
// Populate L1 cache with value from L2 (write-through on read).
|
||||
// Hand off to the bounded backfill worker instead of spawning a goroutine
|
||||
// per read - under burst that would mint thousands of goroutines.
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
close(h.l1BackfillBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
|
||||
// Drops on full buffer to keep the read path constant-time; the next L2 hit
|
||||
// for the same key simply re-queues it.
|
||||
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
|
||||
select {
|
||||
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
|
||||
default:
|
||||
h.l1BackfillDrops.Add(1)
|
||||
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// l1BackfillWorker drains the backfill queue serially. Single worker is
|
||||
// intentional - L1 writes are local and cheap, and serializing them keeps
|
||||
// goroutine count bounded under any read rate.
|
||||
func (h *HybridBackend) l1BackfillWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items best-effort then exit.
|
||||
for len(h.l1BackfillBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.l1BackfillBuffer:
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.l1BackfillBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
|
||||
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
|
||||
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
|
||||
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 256,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const burst = 1000
|
||||
|
||||
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
|
||||
// fresh L1 backfill enqueue.
|
||||
for i := 0; i < burst; i++ {
|
||||
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
|
||||
}
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Issue the burst as fast as possible; the backfill worker MUST be the
|
||||
// only goroutine doing L1 writes. Allow brief slack for the test runtime
|
||||
// scheduling but anything north of +20 means goroutine leakage.
|
||||
peak := baseline
|
||||
for i := 0; i < burst; i++ {
|
||||
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
if g := runtime.NumGoroutine(); g > peak {
|
||||
peak = g
|
||||
}
|
||||
}
|
||||
|
||||
delta := peak - baseline
|
||||
if delta > 20 {
|
||||
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
|
||||
delta, baseline, peak)
|
||||
}
|
||||
|
||||
// L1 must eventually catch up via the worker. Worker drains serially so
|
||||
// give it a generous window proportional to the burst size.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
var populated int
|
||||
for i := 0; i < burst; i++ {
|
||||
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
|
||||
populated++
|
||||
}
|
||||
}
|
||||
// Be lenient: drops are acceptable under buffer pressure, just want
|
||||
// most of the keys to make it.
|
||||
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
|
||||
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
|
||||
}
|
||||
|
||||
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
|
||||
// buffer is saturated. Drops must be counted, never block, never spawn a
|
||||
// goroutine.
|
||||
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
// Tiny buffer + slow primary writes via failSet so the worker stays
|
||||
// blocked enough to overflow the buffer.
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
// Stop the worker from draining: cancel the underlying context so the
|
||||
// worker bails out, leaving us with a cold buffer and the queue method
|
||||
// itself responsible for drop accounting.
|
||||
hybrid.cancel()
|
||||
// Wait for worker to exit so it can't drain.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
|
||||
}
|
||||
|
||||
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
|
||||
"expected some drops when buffer is saturated and worker is stopped")
|
||||
}
|
||||
@@ -226,6 +226,13 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
maxRefreshTokenAge: func() time.Duration {
|
||||
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
|
||||
if config.MaxRefreshTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
@@ -242,6 +249,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -260,6 +268,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
@@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
|
||||
// for both deduplication and per-session attempt tracking. Using sha256 of the
|
||||
// raw token means each rotation produces a fresh sessionID with its own attempt
|
||||
// budget, which is what we want.
|
||||
func refreshCoordinatorSessionID(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
|
||||
// coordinated refresh result. It is wider than RefreshTimeout so a follower
|
||||
// always sees the leader's result instead of timing out independently.
|
||||
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
|
||||
// replica dedup contract: Set/Get with TTL. Used in place of the universal
|
||||
// cache singleton so these tests stay hermetic.
|
||||
type inMemoryCache struct {
|
||||
entries map[string]inMemoryCacheEntry
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type inMemoryCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
value interface{}
|
||||
}
|
||||
|
||||
func newInMemoryCache() *inMemoryCache {
|
||||
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Get(key string) (any, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.entries[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.entries, key)
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) SetMaxSize(int) {}
|
||||
func (c *inMemoryCache) Cleanup() {}
|
||||
func (c *inMemoryCache) Close() {}
|
||||
func (c *inMemoryCache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
func (c *inMemoryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries = map[string]inMemoryCacheEntry{}
|
||||
}
|
||||
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
|
||||
|
||||
// erroringTokenExchanger always errors - simulates an IdP rejection.
|
||||
type erroringTokenExchanger struct {
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
return nil, errors.New("invalid_grant")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
|
||||
|
||||
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
|
||||
// replica having just refreshed: the shared cache already has the result, so
|
||||
// this pod must reuse it without ever calling the IdP.
|
||||
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "should_not_be_called"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
preExisting := &TokenResponse{
|
||||
AccessToken: "from_peer",
|
||||
RefreshToken: "rotated_by_peer",
|
||||
IDToken: "id_from_peer",
|
||||
}
|
||||
rt := "shared_refresh_token"
|
||||
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "from_peer" {
|
||||
t.Fatalf("expected peer-provided response, got %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 0 {
|
||||
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
|
||||
// cache miss the leader stores its result for peers to find within the TTL.
|
||||
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "fresh_grant"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
rt := "fresh_refresh_token"
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected 1 upstream call, got %d", got)
|
||||
}
|
||||
|
||||
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
|
||||
if !ok {
|
||||
t.Fatal("expected refresh result to be cached after upstream success")
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("cached value malformed: %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
|
||||
// dedup cache when the IdP rejects the grant. Peers must run their own
|
||||
// refresh; they cannot inherit an error.
|
||||
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
|
||||
failing := &erroringTokenExchanger{}
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: failing,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
|
||||
t.Fatal("expected an error from the failing exchanger")
|
||||
}
|
||||
if cache.Size() != 0 {
|
||||
t.Fatalf("error result must not be cached, size=%d", cache.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
|
||||
// reads from. We can't reuse sessionPool.Get() here because that requires a
|
||||
// fully initialized SessionManager - overkill for this unit-level check.
|
||||
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
|
||||
t.Helper()
|
||||
rs := sessions.NewSession(nil, "refresh")
|
||||
if !issuedAt.IsZero() {
|
||||
rs.Values["issued_at"] = issuedAt.Unix()
|
||||
}
|
||||
return &SessionData{
|
||||
refreshSession: rs,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 0}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false within max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
|
||||
if !tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
|
||||
var tr *TraefikOidc
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil receiver must not panic and must return false")
|
||||
}
|
||||
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil session must return false")
|
||||
}
|
||||
}
|
||||
+15
@@ -55,6 +55,15 @@ type Config struct {
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
// than this, requests treat it as expired up-front - returning 401 to
|
||||
// AJAX callers and triggering full re-auth on navigations - instead of
|
||||
// hammering the IdP with grants that will only fail with invalid_grant.
|
||||
// IdPs do not expose RT TTL on the wire, so this is intentionally a
|
||||
// conservative heuristic; tune to match your provider configuration.
|
||||
// Default 21600 (6h). Set to 0 to disable the check.
|
||||
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
|
||||
SessionMaxAge int `json:"sessionMaxAge"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
@@ -247,6 +256,7 @@ func CreateConfig() *Config {
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
|
||||
}
|
||||
@@ -370,6 +380,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate refresh-token max-age heuristic
|
||||
if c.MaxRefreshTokenAgeSeconds < 0 {
|
||||
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate audience if specified
|
||||
if c.Audience != "" {
|
||||
// Validate audience format - should be a valid identifier or URL
|
||||
|
||||
+97
-7
@@ -46,6 +46,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Hot-path fast-return: a previously-verified token has already passed
|
||||
// signature, claims, and replay checks. Skipping the parseJWT cost here
|
||||
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
|
||||
// every Grafana dashboard refresh) where the same token is validated
|
||||
// dozens of times per second by validateStandardTokens.
|
||||
if t.tokenCache != nil {
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
parsedJWT, parseErr := parseJWT(token)
|
||||
if parseErr != nil {
|
||||
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
||||
@@ -63,12 +74,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check token cache FIRST - if token is already verified and cached, return immediately
|
||||
// This prevents false positives when multiple goroutines validate the same token concurrently
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only check JTI blacklist for tokens that aren't already in the cache
|
||||
// This is for FIRST-TIME validation to detect replay attacks
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
@@ -416,7 +421,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
}
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
@@ -518,6 +523,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// coordinatedTokenRefresh routes a refresh-token grant through the
|
||||
// RefreshCoordinator so that concurrent requests sharing the same refresh
|
||||
// token coalesce into a single upstream call. This prevents the thundering
|
||||
// herd that yields invalid_grant when the IdP rotates refresh tokens.
|
||||
//
|
||||
// Falls back to a direct call when the coordinator is nil, which only
|
||||
// happens in tests that build TraefikOidc literals without going through
|
||||
// NewWithContext.
|
||||
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
|
||||
if t.refreshCoordinator == nil {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
parentCtx := context.Background()
|
||||
if req != nil {
|
||||
parentCtx = req.Context()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
|
||||
defer cancel()
|
||||
|
||||
sessionID := refreshCoordinatorSessionID(refreshToken)
|
||||
|
||||
return t.refreshCoordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
func() (*TokenResponse, error) {
|
||||
// Cross-replica dedup. The in-process coordinator already
|
||||
// collapses concurrent grants on this pod; this Redis-backed
|
||||
// short-TTL cache covers the (rare) case of a failover or
|
||||
// load-balancer reroute mid-refresh, where two pods would
|
||||
// otherwise both POST the same refresh_token to the IdP.
|
||||
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
if err == nil && resp != nil {
|
||||
t.cacheRefreshResult(sessionID, resp)
|
||||
}
|
||||
return resp, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
|
||||
// given refresh-token hash, if one exists and is still within its short TTL.
|
||||
// The cache wraps the universal cache, which is Redis-backed in production -
|
||||
// so a "hit" here means another Traefik replica refreshed this same token
|
||||
// within the last few seconds.
|
||||
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
|
||||
if t.refreshResultCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
|
||||
if !ok || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); ok && tr != nil {
|
||||
return tr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cacheRefreshResult stores the new TokenResponse under the refresh-token
|
||||
// hash for a short window. TTL is intentionally tight: the rotated refresh
|
||||
// token cannot be re-presented to the IdP, and any peer waiting longer than
|
||||
// this window has almost certainly given up via its own coordinator timeout.
|
||||
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
|
||||
if t.refreshResultCache == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
|
||||
}
|
||||
|
||||
// refreshResultCacheKey namespaces refresh-result entries inside the shared
|
||||
// cache namespace.
|
||||
func refreshResultCacheKey(sessionID string) string {
|
||||
return "rt-result:" + sessionID
|
||||
}
|
||||
|
||||
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
|
||||
// Long enough for a sibling replica to observe the result, short enough that
|
||||
// a stale entry never re-supplies a token after the IdP has already moved on.
|
||||
const refreshResultCacheTTL = 5 * time.Second
|
||||
|
||||
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
||||
// It removes the token from the verification cache and adds both the token
|
||||
// and its JTI (if present) to the blacklist to prevent future use.
|
||||
|
||||
@@ -95,6 +95,7 @@ type TraefikOidc struct {
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
tokenResilienceManager *TokenResilienceManager
|
||||
refreshCoordinator *RefreshCoordinator
|
||||
goroutineWG *sync.WaitGroup
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
dynamicClientRegistrar *DynamicClientRegistrar
|
||||
@@ -124,11 +125,13 @@ type TraefikOidc struct {
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
|
||||
@@ -23,6 +23,7 @@ type UniversalCacheManager struct {
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
|
||||
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
@@ -181,6 +182,18 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
|
||||
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
|
||||
// in memory-only mode it's effectively redundant with RefreshCoordinator
|
||||
// but safe and cheap to keep.
|
||||
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
@@ -387,6 +400,21 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
createBackend("session_invalidation"),
|
||||
)
|
||||
|
||||
// Refresh-result cache - shared via Redis so concurrent refreshes across
|
||||
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
|
||||
// peers to observe a recent refresh and short enough that a stale entry
|
||||
// can't be replayed against a now-rotated refresh token.
|
||||
manager.refreshResultCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("refresh_result"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -436,6 +464,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
m.sessionInvalidationCache,
|
||||
m.refreshResultCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -498,6 +527,14 @@ func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
|
||||
return m.sessionInvalidationCache
|
||||
}
|
||||
|
||||
// GetRefreshResultCache returns the short-lived refresh-result cache used to
|
||||
// coalesce refresh-token grants across Traefik replicas.
|
||||
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.refreshResultCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
@@ -520,7 +557,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
|
||||
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
|
||||
t.safeLogDebug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
if t.refreshCoordinator != nil {
|
||||
t.refreshCoordinator.Shutdown()
|
||||
t.safeLogDebug("refreshCoordinator shut down")
|
||||
}
|
||||
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestVerifyToken_CacheHitSkipsParse proves the hot-path optimization: when a
|
||||
// token is in the cache, VerifyToken returns nil without calling parseJWT.
|
||||
// We construct a token that PASSES the cheap format checks (3 segments, len
|
||||
// >= 10) but whose body is unparseable JSON. With the cache hit hoisted ahead
|
||||
// of parseJWT, the function returns nil. Without the hoist, parseJWT would
|
||||
// fail with "failed to parse JWT for blacklist check".
|
||||
func TestVerifyToken_CacheHitSkipsParse(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
// limiter intentionally absent; if we reached the rate-limit check
|
||||
// the test would NPE - this is a stronger assertion that we exit
|
||||
// before that point.
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Three segments separated by '.', body is junk after base64-decode + JSON.
|
||||
// Pre-fix this fails parseJWT; post-fix it returns nil because the cache
|
||||
// short-circuits.
|
||||
junkToken := "header.bm90LWpzb24.signature" // base64(not-json) in the middle
|
||||
tr.tokenCache.Set(junkToken, map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"sub": "test",
|
||||
}, time.Hour)
|
||||
|
||||
if err := tr.VerifyToken(junkToken); err != nil {
|
||||
t.Fatalf("expected cache-hit fast path to return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyToken_CacheMissStillParses ensures we did not skip too aggressively
|
||||
// - on a cache miss, the function must still parse and reach the rate-limit
|
||||
// check. We assert by passing a syntactically valid token whose signature
|
||||
// won't verify, expecting an error from later in the pipeline.
|
||||
func TestVerifyToken_CacheMissStillParses(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
// no tokenBlacklist, no jwkCache - the function will fail somewhere
|
||||
// after parseJWT. We just need a non-nil error to confirm we did
|
||||
// progress past the cache check.
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Real JWT structure but unsigned/unverifiable.
|
||||
rawToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"
|
||||
|
||||
if err := tr.VerifyToken(rawToken); err == nil {
|
||||
t.Fatal("expected an error past parseJWT for an unsigned token, got nil")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user