mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 17e3f8ef62 | |||
| 827926bc3a | |||
| abbfdb02a7 | |||
| 72e2b682bb |
+2
-2
@@ -71,8 +71,8 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
|
||||
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
|
||||
// Test 3: Rate limiting
|
||||
t.Run("RateLimiting", func(t *testing.T) {
|
||||
// Reset circuit breaker to closed state for this test
|
||||
coordinator.circuitBreaker.mutex.Lock()
|
||||
// Reset circuit breaker to closed state for this test. All fields are
|
||||
// atomic so we don't need any mutex.
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
|
||||
coordinator.circuitBreaker.mutex.Unlock()
|
||||
|
||||
// Temporarily increase circuit breaker threshold to not interfere
|
||||
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
|
||||
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
time.Sleep(config.CleanupInterval * 3)
|
||||
|
||||
// Old sessions should be cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
count := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
count := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have fewer sessions after cleanup
|
||||
if count > 10 {
|
||||
|
||||
@@ -53,10 +53,26 @@ type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache.
|
||||
//
|
||||
// inflightFetches deduplicates concurrent fetches for the same JWKS URL.
|
||||
// It replaces a global sync.RWMutex that was previously held for the entire
|
||||
// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief
|
||||
// network blip) every concurrent request piled up on that single Lock(), and
|
||||
// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch
|
||||
// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP
|
||||
// fetch regardless of how many requests are waiting.
|
||||
type JWKCache struct {
|
||||
cache *UniversalCache
|
||||
mutex sync.RWMutex
|
||||
cache *UniversalCache
|
||||
inflightFetches sync.Map // map[jwksURL string]*jwksFetch
|
||||
}
|
||||
|
||||
// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch
|
||||
// completes; jwks and err carry the result (one of them is set, never both).
|
||||
type jwksFetch struct {
|
||||
done chan struct{}
|
||||
jwks *JWKSet
|
||||
err error
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the contract for JWK caching implementations.
|
||||
@@ -83,36 +99,58 @@ func NewJWKCache() *JWKCache {
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
// Fast path: cache hit.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Singleflight: dedupe concurrent fetches per URL key. The first arrival
|
||||
// performs the HTTP fetch; any later arrival for the same URL waits on
|
||||
// its done channel and shares the result. No global lock is held during
|
||||
// the fetch.
|
||||
candidate := &jwksFetch{done: make(chan struct{})}
|
||||
if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded {
|
||||
f, _ := existing.(*jwksFetch)
|
||||
select {
|
||||
case <-f.done:
|
||||
return f.jwks, f.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Double-check after acquiring lock
|
||||
// We're the leader. Make absolutely sure the result fields and the
|
||||
// in-flight map entry are cleaned up before any waiter unblocks.
|
||||
defer func() {
|
||||
c.inflightFetches.Delete(jwksURL)
|
||||
close(candidate.done)
|
||||
}()
|
||||
|
||||
// Re-check the cache in case a concurrent fetch completed between our
|
||||
// initial miss and our LoadOrStore win.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from URL
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
candidate.err = err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(jwks.Keys) == 0 {
|
||||
return nil, fmt.Errorf("JWKS response contains no keys")
|
||||
candidate.err = fmt.Errorf("JWKS response contains no keys")
|
||||
return nil, candidate.err
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
// Cache for 1 hour.
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
|
||||
+4
-4
@@ -415,8 +415,8 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -457,8 +457,8 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
@@ -517,6 +517,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
introspectionURL := t.introspectionURL
|
||||
registrationURL := t.registrationURL
|
||||
|
||||
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
|
||||
// this directly instead of acquiring metadataMu.RLock per request.
|
||||
t.metadataSnapshot.Store(&MetadataSnapshot{
|
||||
IssuerURL: metadata.Issuer,
|
||||
JWKSURL: metadata.JWKSURL,
|
||||
TokenURL: metadata.TokenURL,
|
||||
AuthURL: metadata.AuthURL,
|
||||
RevocationURL: metadata.RevokeURL,
|
||||
EndSessionURL: metadata.EndSessionURL,
|
||||
IntrospectionURL: metadata.IntrospectionURL,
|
||||
RegistrationURL: metadata.RegistrationURL,
|
||||
})
|
||||
|
||||
t.metadataMu.Unlock()
|
||||
|
||||
// Log introspection endpoint availability for opaque token support
|
||||
|
||||
+14
-30
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -484,9 +485,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -508,19 +508,13 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate first request processing
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
// Simulate first request processing — single-firing via CAS.
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
// This would normally be called asynchronously
|
||||
go func() {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
// initComplete is closed internally by initializeMetadata
|
||||
}()
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for initialization
|
||||
@@ -556,9 +550,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -580,31 +573,22 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate multiple concurrent "first" requests
|
||||
// Simulate multiple concurrent "first" requests — only one CAS winner
|
||||
// fires the bootstrap path.
|
||||
const numRequests = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
initStarted := 0
|
||||
var initMu sync.Mutex
|
||||
var initStarted int32
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
initMu.Lock()
|
||||
initStarted++
|
||||
initMu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
atomic.AddInt32(&initStarted, 1)
|
||||
// Only one should actually start initialization
|
||||
oidc.initializeMetadata(server.URL)
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -612,8 +596,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// Verify only one initialization was started
|
||||
if initStarted != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
|
||||
if atomic.LoadInt32(&initStarted) != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
|
||||
}
|
||||
|
||||
// The metadata endpoint might be called once or not at all depending on timing
|
||||
|
||||
+28
-28
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com", // Required for initialization check
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -92,8 +92,8 @@ func TestServeHTTP_EventStream(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -175,8 +175,8 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -272,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close this to simulate timeout
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
@@ -307,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -337,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -367,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -740,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -817,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -903,8 +903,8 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: tt.stripAuthCookies,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -987,8 +987,8 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1034,8 +1034,8 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1085,8 +1085,8 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1148,8 +1148,8 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
|
||||
+4
-4
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -2685,10 +2686,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable = true
|
||||
mu.Unlock()
|
||||
|
||||
// Reset the retry timer to allow immediate retry
|
||||
m.metadataRetryMutex.Lock()
|
||||
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
|
||||
m.metadataRetryMutex.Unlock()
|
||||
// Reset the retry timer to allow immediate retry. The field is atomic
|
||||
// now, so no lock is needed.
|
||||
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
|
||||
|
||||
// Second request should trigger recovery attempt
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
|
||||
+31
-18
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
@@ -145,19 +146,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
t.firstRequestReceived = true
|
||||
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
|
||||
// fired on EVERY non-health request forever (even after the boolean
|
||||
// flipped true), which under Yaegi added a per-request serialization
|
||||
// point. CAS gives single-firing semantics with zero steady-state cost.
|
||||
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
|
||||
t.logger.Debug("Starting background tasks on first request")
|
||||
t.startTokenCleanup()
|
||||
|
||||
if !t.metadataRefreshStarted && t.providerURL != "" {
|
||||
t.metadataRefreshStarted = true
|
||||
if t.providerURL != "" &&
|
||||
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
|
||||
// Metadata refresh is handled by singleton resource manager
|
||||
t.startMetadataRefresh(t.providerURL)
|
||||
}
|
||||
}
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
@@ -207,20 +209,31 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
// Read issuerURL via atomic snapshot when available — replaces the
|
||||
// metadataMu.RLock that previously fired on every non-bypass request.
|
||||
// Under Yaegi each RLock acquisition costs 1-5ms of interpreter
|
||||
// dispatch; the snapshot is a single atomic.Value.Load. Falls back
|
||||
// to the legacy field+RLock for paths that haven't published a
|
||||
// snapshot yet (notably some test setups that initialize the struct
|
||||
// fields directly).
|
||||
var issuerURL string
|
||||
if snap := t.metadataSnap(); snap != nil {
|
||||
issuerURL = snap.IssuerURL
|
||||
} else {
|
||||
t.metadataMu.RLock()
|
||||
issuerURL = t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
}
|
||||
|
||||
if issuerURL == "" {
|
||||
// Provider metadata initialization failed - try to recover
|
||||
// Retry every 30 seconds to allow automatic recovery when provider comes back online
|
||||
t.metadataRetryMutex.Lock()
|
||||
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
|
||||
if shouldRetry {
|
||||
t.lastMetadataRetryTime = time.Now()
|
||||
}
|
||||
t.metadataRetryMutex.Unlock()
|
||||
// Provider metadata initialization failed - try to recover.
|
||||
// Retry every 30 seconds to allow automatic recovery. Lock-free
|
||||
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
|
||||
// the window, others see shouldRetry=false.
|
||||
nowNano := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
|
||||
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
|
||||
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
|
||||
|
||||
if shouldRetry && t.providerURL != "" {
|
||||
t.logger.Info("Attempting to recover OIDC provider metadata...")
|
||||
|
||||
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate waiting
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
// Create request with canceled context
|
||||
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
@@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
|
||||
+276
-170
@@ -21,17 +21,22 @@ type RefreshCoordinator struct {
|
||||
// refreshMutex.Lock() was held for tens of milliseconds per request due
|
||||
// to interpreter overhead on the work inside the critical section,
|
||||
// causing dozens of goroutines to stack up on it and pin one CPU core.
|
||||
inFlightRefreshes sync.Map
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
inFlightRefreshes sync.Map
|
||||
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
|
||||
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
|
||||
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
|
||||
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
|
||||
// adds 10-50ms of dispatch overhead, and they were called twice per leader
|
||||
// request (once for recordRefreshAttempt, once for isInCooldown). That
|
||||
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
|
||||
// removed the refreshMutex (same architectural shape, different mutex).
|
||||
sessionRefreshAttempts sync.Map
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
config RefreshCoordinatorConfig
|
||||
wg sync.WaitGroup
|
||||
attemptsMutex sync.RWMutex
|
||||
cleanupTimerMu sync.Mutex
|
||||
}
|
||||
|
||||
// RefreshCoordinatorConfig configures the refresh coordinator behavior
|
||||
@@ -89,14 +94,46 @@ type refreshResult struct {
|
||||
fromCache bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session
|
||||
type refreshAttemptTracker struct {
|
||||
lastAttemptTime time.Time
|
||||
windowStartTime time.Time
|
||||
cooldownEndTime time.Time
|
||||
// attemptState is the immutable snapshot of a session's refresh-attempt
|
||||
// state. Lives behind refreshAttemptTracker.state (atomic.Value). Every
|
||||
// transition (record, success, failure, window-reset, cooldown-enter,
|
||||
// cooldown-exit) constructs a fresh attemptState and publishes it via
|
||||
// CompareAndSwap so the entire field set is updated together.
|
||||
//
|
||||
// Per-field atomic.Load/Store (the previous v1.0.15 design) had a benign
|
||||
// but observable hazard: the cooldown-exit reset wrote cooldownEndNano = 0
|
||||
// first, then separately stored attempts = 1 and windowStartNano = now.
|
||||
// A concurrent isInCooldown call could see cooldownEndNano = 0 (reset
|
||||
// just completed) with attempts still at MaxRefreshAttempts, triggering
|
||||
// a fresh cooldown immediately. The snapshot approach eliminates the
|
||||
// intermediate state entirely.
|
||||
type attemptState struct {
|
||||
lastAttemptNano int64 // UnixNano of last attempt
|
||||
windowStartNano int64 // UnixNano of attempt-window start
|
||||
cooldownEndNano int64 // UnixNano; 0 = not in cooldown
|
||||
attempts int32
|
||||
consecutiveFailures int32
|
||||
inCooldown bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session via a single
|
||||
// atomic.Value holding a *attemptState pointer. Readers do exactly one Load.
|
||||
// Writers do Load → construct new → CompareAndSwap (retry on conflict).
|
||||
// Under Yaegi this collapses 3-4 per-field atomic dispatches into one Load,
|
||||
// and eliminates the cross-field race in the window-reset path.
|
||||
type refreshAttemptTracker struct {
|
||||
state atomic.Value // *attemptState
|
||||
}
|
||||
|
||||
// stateOf returns the current attemptState, or a zero-value snapshot if none
|
||||
// has been published yet. The empty snapshot represents "no attempts recorded".
|
||||
func (t *refreshAttemptTracker) stateOf() *attemptState {
|
||||
if v := t.state.Load(); v != nil {
|
||||
s, _ := v.(*attemptState)
|
||||
if s != nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return &attemptState{}
|
||||
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics
|
||||
@@ -111,14 +148,18 @@ type RefreshMetrics struct {
|
||||
currentInFlightRefreshes int32
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
|
||||
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
|
||||
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
|
||||
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
|
||||
// interpreter dispatch per call, which compounded with attemptsMutex to keep
|
||||
// the pod's single CPU core saturated.
|
||||
type RefreshCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureNano int64 // atomic, UnixNano of most recent failure
|
||||
lastSuccessNano int64 // atomic, UnixNano of most recent success
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
state int32
|
||||
failures int32
|
||||
state int32 // atomic: 0=closed, 1=open, 2=half-open
|
||||
failures int32 // atomic
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
|
||||
@@ -135,13 +176,12 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
}
|
||||
|
||||
rc := &RefreshCoordinator{
|
||||
// inFlightRefreshes is a sync.Map; zero value is ready to use.
|
||||
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
|
||||
// their zero values are ready to use.
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
@@ -269,19 +309,22 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// Reserve concurrent slot via CAS — without the old global lock we can
|
||||
// no longer rely on mutex-mediated check-then-increment. If we lose the
|
||||
// CAS race we retry; if the limit has since been reached we back out.
|
||||
for {
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
err := fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
if atomic.CompareAndSwapInt32(&rc.metrics.currentInFlightRefreshes, current, current+1) {
|
||||
break
|
||||
}
|
||||
// Reserve concurrent slot via ticket-and-return: increment optimistically,
|
||||
// decrement if we overshot the limit. The previous CAS-loop allowed a
|
||||
// transient overshoot of up to N-1 leaders when several goroutines all
|
||||
// observed `current < max` in the same scheduling slice before any one
|
||||
// of them succeeded their CAS — visible to readers as
|
||||
// currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window.
|
||||
// The ticket pattern is strictly bounded: the counter momentarily reads
|
||||
// max+k for k concurrent attempts past the limit, but only the k that
|
||||
// produced max+1..max+k decrement back, and only k=1 ever observes max+1
|
||||
// as committed.
|
||||
newCount := atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
if int(newCount) > rc.config.MaxConcurrentRefreshes {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
err := fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return candidate, true, nil
|
||||
@@ -292,7 +335,13 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
// goroutine that just registered the operation) runs them; joiners share the
|
||||
// leader's outcome via operation.done.
|
||||
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
// Cooldown check FIRST, BEFORE incrementing the attempt counter.
|
||||
// Previously this function recorded the attempt and then read the
|
||||
// cooldown state. Under burst load (many concurrent leaders with
|
||||
// different token hashes but same session) every goroutine could
|
||||
// increment past MaxRefreshAttempts before any one of them observed
|
||||
// the threshold, so the cooldown gate fired too late — the same
|
||||
// thundering-herd shape that drove v1.0.14 into the ground.
|
||||
if rc.isInCooldown(sessionID) {
|
||||
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
|
||||
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
@@ -301,6 +350,8 @@ func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
|
||||
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
|
||||
return fmt.Errorf("system under memory pressure, refresh denied")
|
||||
}
|
||||
// Only count attempts that actually progress past the gates.
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -377,31 +428,25 @@ func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
}
|
||||
}
|
||||
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine
|
||||
// This prevents goroutine explosion under high load (500+ req/sec)
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning
|
||||
// a goroutine — time.AfterFunc uses the runtime's timer heap and never spawns
|
||||
// a per-timer goroutine until the callback actually fires.
|
||||
//
|
||||
// The previous implementation tracked every pending timer in a map guarded by
|
||||
// cleanupTimerMu so a duplicate scheduling could cancel the prior timer. That
|
||||
// "shouldn't happen" path was the only consumer of the map, but the mutex
|
||||
// fired on every successful refresh completion — yet another per-request
|
||||
// Yaegi-dispatched lock acquisition. performCleanup is already idempotent
|
||||
// (LoadAndDelete on the sync.Map), so a duplicate scheduling at worst fires
|
||||
// performCleanup twice; the second call is a no-op. Dropping the map removes
|
||||
// the whole class of contention on this code path.
|
||||
func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
|
||||
delay := rc.config.DeduplicationCleanupDelay
|
||||
if delay <= 0 {
|
||||
// Immediate cleanup
|
||||
rc.performCleanup(tokenHash)
|
||||
return
|
||||
}
|
||||
|
||||
// Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep
|
||||
// time.AfterFunc uses the runtime's timer heap which is much more efficient
|
||||
rc.cleanupTimerMu.Lock()
|
||||
// Cancel any existing timer for this hash (shouldn't happen, but just in case)
|
||||
if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists {
|
||||
existingTimer.Stop()
|
||||
}
|
||||
rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() {
|
||||
rc.performCleanup(tokenHash)
|
||||
// Remove timer from map
|
||||
rc.cleanupTimerMu.Lock()
|
||||
delete(rc.cleanupTimers, tokenHash)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
})
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
time.AfterFunc(delay, func() { rc.performCleanup(tokenHash) })
|
||||
}
|
||||
|
||||
// performCleanup removes the operation from the in-flight map.
|
||||
@@ -415,87 +460,164 @@ func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown after recording an attempt
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
|
||||
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
|
||||
// concurrent first-touch races: at most one tracker per sessionID survives.
|
||||
//
|
||||
// trackerFromMapValue centralizes the type assertion so the lint-mandated
|
||||
// two-value form lives in one place; the stored type is always
|
||||
// *refreshAttemptTracker by construction.
|
||||
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
|
||||
t, _ := v.(*refreshAttemptTracker)
|
||||
return t
|
||||
}
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
return trackerFromMapValue(v)
|
||||
}
|
||||
fresh := &refreshAttemptTracker{}
|
||||
fresh.state.Store(&attemptState{windowStartNano: time.Now().UnixNano()})
|
||||
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
|
||||
return trackerFromMapValue(actual)
|
||||
}
|
||||
|
||||
// mutateState performs a CompareAndSwap loop that applies mutate to the
|
||||
// current snapshot. mutate must be PURE: it receives an immutable view of
|
||||
// the current state and returns a fresh *attemptState. If mutate returns nil
|
||||
// the update is skipped (used by isInCooldown for "no change needed" paths).
|
||||
//
|
||||
// Retries on CAS conflict are bounded by the number of concurrent writers —
|
||||
// in practice 1-3. Under Yaegi each retry pays the dispatch cost of one Load
|
||||
// + one CompareAndSwap; still cheaper than the previous per-field atomic
|
||||
// sequence and immune to the cross-field race the v1.0.15 design had.
|
||||
func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *attemptState) *attemptState {
|
||||
for {
|
||||
cur := t.stateOf()
|
||||
next := mutate(cur)
|
||||
if next == nil {
|
||||
return cur
|
||||
}
|
||||
if t.state.CompareAndSwap(t.state.Load(), next) {
|
||||
return next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown. Snapshot-based: every
|
||||
// transition publishes a fresh *attemptState atomically so readers never see
|
||||
// a partially-updated state. The previous per-field atomic design had a
|
||||
// benign race in the cooldown-exit path (cooldownEndNano reset before
|
||||
// attempts reset) that could double-trigger cooldown.
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return false // No tracker means first attempt, not in cooldown
|
||||
}
|
||||
|
||||
tracker := trackerFromMapValue(v)
|
||||
now := time.Now()
|
||||
nowNano := now.UnixNano()
|
||||
maxAttempts := rc.config.MaxRefreshAttempts
|
||||
window := rc.config.RefreshAttemptWindow
|
||||
cooldownPeriod := rc.config.RefreshCooldownPeriod
|
||||
|
||||
// Check if already in cooldown
|
||||
if tracker.inCooldown {
|
||||
if now.After(tracker.cooldownEndTime) {
|
||||
// Cooldown expired, reset tracker
|
||||
tracker.inCooldown = false
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.consecutiveFailures = 0
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
cur := tracker.stateOf()
|
||||
|
||||
// Already in cooldown?
|
||||
if cur.cooldownEndNano != 0 {
|
||||
if nowNano <= cur.cooldownEndNano {
|
||||
return true // still in cooldown
|
||||
}
|
||||
return true // Still in cooldown
|
||||
}
|
||||
|
||||
// Check if window expired
|
||||
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
|
||||
// Reset window
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.windowStartTime = now
|
||||
// Cooldown expired: atomically publish a fresh state with the window
|
||||
// restarted from one attempt. Whichever goroutine wins the CAS sets
|
||||
// the new snapshot; losers see it via the next stateOf load.
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano == 0 || nowNano <= s.cooldownEndNano {
|
||||
return nil // someone else already reset, or back in cooldown
|
||||
}
|
||||
return &attemptState{
|
||||
windowStartNano: nowNano,
|
||||
attempts: 1,
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if just exceeded attempt limit
|
||||
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
|
||||
// Enter cooldown now
|
||||
tracker.inCooldown = true
|
||||
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, tracker.attempts)
|
||||
// Window expired?
|
||||
if time.Duration(nowNano-cur.windowStartNano) > window {
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if time.Duration(nowNano-s.windowStartNano) <= window {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.windowStartNano = nowNano
|
||||
next.attempts = 1
|
||||
return &next
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Just exceeded attempt limit?
|
||||
if int(cur.attempts) >= maxAttempts {
|
||||
end := now.Add(cooldownPeriod).UnixNano()
|
||||
published := tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano != 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.cooldownEndNano = end
|
||||
return &next
|
||||
})
|
||||
if published.cooldownEndNano == end {
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, published.attempts)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free
|
||||
// snapshot mutation; attempts and lastAttemptNano are advanced atomically.
|
||||
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
tracker = &refreshAttemptTracker{
|
||||
windowStartTime: time.Now(),
|
||||
}
|
||||
rc.sessionRefreshAttempts[sessionID] = tracker
|
||||
}
|
||||
|
||||
atomic.AddInt32(&tracker.attempts, 1)
|
||||
tracker.lastAttemptTime = time.Now()
|
||||
tracker := rc.getOrCreateTracker(sessionID)
|
||||
nowNano := time.Now().UnixNano()
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.attempts++
|
||||
next.lastAttemptNano = nowNano
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshSuccess records a successful refresh
|
||||
// recordRefreshSuccess records a successful refresh: zero consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
tracker.consecutiveFailures = 0
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
if s.consecutiveFailures == 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.consecutiveFailures = 0
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshFailure records a failed refresh
|
||||
// recordRefreshFailure records a failed refresh: increments consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
atomic.AddInt32(&tracker.consecutiveFailures, 1)
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.consecutiveFailures++
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
@@ -546,20 +668,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleEntries removes outdated tracking entries
|
||||
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
|
||||
// via sync.Map.Range; safe to race with concurrent reads/writes.
|
||||
func (rc *RefreshCoordinator) cleanupStaleEntries() {
|
||||
now := time.Now()
|
||||
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
// Clean up old session trackers
|
||||
for sessionID, tracker := range rc.sessionRefreshAttempts {
|
||||
// Remove trackers that haven't been used recently
|
||||
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
|
||||
delete(rc.sessionRefreshAttempts, sessionID)
|
||||
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
|
||||
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
|
||||
tracker := trackerFromMapValue(value)
|
||||
if tracker == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if tracker.stateOf().lastAttemptNano < cutoff {
|
||||
// Compare-and-delete to avoid evicting a tracker that was just
|
||||
// re-used by a concurrent caller. We compare by pointer identity.
|
||||
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetMetrics returns current coordinator metrics
|
||||
@@ -577,78 +701,60 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the coordinator
|
||||
// Shutdown gracefully shuts down the coordinator. Pending delayed-cleanup
|
||||
// timers are NOT canceled explicitly: time.AfterFunc callbacks are tiny
|
||||
// (one map LoadAndDelete) and harmless after Shutdown — sync.Map operations
|
||||
// remain safe on an unused coordinator until GC.
|
||||
func (rc *RefreshCoordinator) Shutdown() {
|
||||
close(rc.stopChan)
|
||||
|
||||
// Cancel all pending cleanup timers
|
||||
rc.cleanupTimerMu.Lock()
|
||||
for _, timer := range rc.cleanupTimers {
|
||||
timer.Stop()
|
||||
}
|
||||
rc.cleanupTimers = make(map[string]*time.Timer)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
|
||||
rc.wg.Wait()
|
||||
}
|
||||
|
||||
// AllowRequest checks if the circuit breaker allows a request
|
||||
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
switch state {
|
||||
case 0: // Closed
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0: // closed
|
||||
return true
|
||||
case 1: // Open
|
||||
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
|
||||
// Try to transition to half-open
|
||||
case 1: // open
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
|
||||
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
|
||||
// Transition to half-open; first CAS winner gets the probe.
|
||||
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case 2: // Half-open
|
||||
case 2: // half-open
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
// RecordSuccess records a successful operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == 2 { // Half-open
|
||||
// Close the circuit
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 2: // half-open -> close
|
||||
atomic.StoreInt32(&cb.state, 0)
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
} else if state == 0 { // Closed
|
||||
// Reset failure count on success
|
||||
case 0: // closed
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
}
|
||||
cb.lastSuccessTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
// RecordFailure records a failed operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
failures := atomic.AddInt32(&cb.failures, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
if state == 0 && int(failures) >= cb.config.MaxFailures {
|
||||
// Open the circuit
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
} else if state == 2 {
|
||||
// Half-open failed, return to open
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0:
|
||||
if int(failures) >= cb.config.MaxFailures {
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
case 2:
|
||||
// Half-open probe failed -> back to open.
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-34
@@ -165,9 +165,14 @@ func TestRefreshRateLimiting(t *testing.T) {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Verify that cooldown was triggered after max attempts
|
||||
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
|
||||
// Verify that cooldown was triggered after max attempts.
|
||||
// With applyLeaderGates checking cooldown BEFORE recording the attempt
|
||||
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
|
||||
// run to completion and the (N+1)th is denied. Previously the Nth was
|
||||
// denied as it tried to record, which under burst load let multiple
|
||||
// concurrent leaders increment past the limit before any one of them
|
||||
// observed the gate.
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts
|
||||
if attempts != expectedSuccessfulAttempts {
|
||||
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
|
||||
}
|
||||
@@ -365,10 +370,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cleanup is working
|
||||
coordinator.attemptsMutex.RLock()
|
||||
sessionCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
// Verify cleanup is working. sync.Map has no Len(); count via Range.
|
||||
sessionCount := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
sessionCount++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have cleaned up old sessions (only recent ones remain)
|
||||
if sessionCount > numWorkers*2 {
|
||||
@@ -650,24 +657,23 @@ func TestCleanupRoutine(t *testing.T) {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
||||
}
|
||||
|
||||
// Verify sessions exist
|
||||
coordinator.attemptsMutex.RLock()
|
||||
initialCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
countSessions := func() int {
|
||||
n := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
n++
|
||||
return true
|
||||
})
|
||||
return n
|
||||
}
|
||||
|
||||
if initialCount != 5 {
|
||||
if initialCount := countSessions(); initialCount != 5 {
|
||||
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (2x window + cleanup interval)
|
||||
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
||||
|
||||
// Verify sessions were cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
finalCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
if finalCount := countSessions(); finalCount != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
@@ -720,11 +726,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
|
||||
|
||||
// Check timer count
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
timerCount := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
t.Logf("Active cleanup timers: %d", timerCount)
|
||||
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
|
||||
// fire performCleanup directly. This test now only checks the goroutine
|
||||
// budget, which was always the real invariant.)
|
||||
|
||||
// With timer-based cleanup, goroutine increase should be minimal
|
||||
// Timers don't create goroutines - they use the runtime timer heap
|
||||
@@ -740,19 +744,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
initialGoroutines, currentGoroutines, goroutineIncrease)
|
||||
}
|
||||
|
||||
// Wait for timers to fire and cleanup
|
||||
// Wait for timers to fire and cleanup.
|
||||
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
|
||||
|
||||
// Verify timers were cleaned up
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
remainingTimers := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
|
||||
// Most timers should have fired and been removed
|
||||
if remainingTimers > 10 {
|
||||
t.Errorf("Too many cleanup timers remaining: %d", remainingTimers)
|
||||
}
|
||||
|
||||
// Verify goroutines returned to near initial
|
||||
runtime.GC()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
@@ -64,8 +65,46 @@ type ProviderMetadata struct {
|
||||
// It integrates with various OIDC providers, manages sessions, caches tokens, and handles
|
||||
// the complete authentication flow. It's designed to work seamlessly with Traefik's
|
||||
// plugin system and provides flexible configuration options.
|
||||
// MetadataSnapshot is an immutable bundle of provider-metadata URLs that the
|
||||
// plugin needs on the hot request path. Published atomically via
|
||||
// TraefikOidc.metadataSnapshot; readers do exactly one atomic.Value.Load to
|
||||
// access all fields. Replaces 3 per-request metadataMu.RLock acquisitions
|
||||
// in middleware.ServeHTTP + token_manager paths, each of which paid
|
||||
// 1-5ms of Yaegi-dispatch overhead.
|
||||
//
|
||||
// The fields are a strict subset of the metadataMu-guarded TraefikOidc
|
||||
// fields; the legacy fields are still written under metadataMu for
|
||||
// less-frequent code paths that have not been migrated.
|
||||
type MetadataSnapshot struct {
|
||||
IssuerURL string
|
||||
JWKSURL string
|
||||
TokenURL string
|
||||
AuthURL string
|
||||
RevocationURL string
|
||||
EndSessionURL string
|
||||
IntrospectionURL string
|
||||
RegistrationURL string
|
||||
}
|
||||
|
||||
type TraefikOidc struct {
|
||||
lastMetadataRetryTime time.Time
|
||||
// metadataSnapshot atomically publishes the read-mostly URL bundle.
|
||||
// Hot-path readers (middleware.ServeHTTP, token verification) load it
|
||||
// directly; less-frequent paths still acquire metadataMu.RLock and
|
||||
// read the individual fields below.
|
||||
metadataSnapshot atomic.Value
|
||||
// lastMetadataRetryNano is the UnixNano timestamp of the last metadata
|
||||
// recovery attempt. Stored atomically so the hot ServeHTTP path can
|
||||
// throttle retries without acquiring metadataRetryMutex on every request.
|
||||
lastMetadataRetryNano int64
|
||||
// firstRequestStarted is 0 until the very first non-health request fires
|
||||
// the background-task bootstrap; then it flips to 1 via CAS. Replaces the
|
||||
// firstRequestMutex + firstRequestReceived combo which previously took
|
||||
// a write lock on every non-health request forever.
|
||||
firstRequestStarted int32
|
||||
// metadataRefreshStartedAtomic is the CAS-only variant of the old
|
||||
// metadataRefreshStarted bool. Both flags live under the same atomic so
|
||||
// concurrent first-request goroutines race exactly once.
|
||||
metadataRefreshStartedAtomic int32
|
||||
jwkCache JWKCacheInterface
|
||||
jwtVerifier JWTVerifier
|
||||
ctx context.Context
|
||||
@@ -130,17 +169,13 @@ type TraefikOidc struct {
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
enableFrontchannelLogout bool
|
||||
firstRequestReceived bool
|
||||
requireTokenIntrospection bool
|
||||
metadataRefreshStarted bool
|
||||
allowPrivateIPAddresses bool
|
||||
disableReplayDetection bool
|
||||
allowOpaqueTokens bool
|
||||
|
||||
@@ -14,6 +14,19 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// metadataSnap returns the most recently published *MetadataSnapshot, or nil
|
||||
// if metadata has not yet been resolved. Single atomic.Value.Load — the hot
|
||||
// ServeHTTP path uses this instead of acquiring metadataMu.RLock, which under
|
||||
// Yaegi pays 1-5ms of interpreter-dispatch overhead per acquisition.
|
||||
func (t *TraefikOidc) metadataSnap() *MetadataSnapshot {
|
||||
v := t.metadataSnapshot.Load()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, _ := v.(*MetadataSnapshot)
|
||||
return s
|
||||
}
|
||||
|
||||
// safeLogDebug provides nil-safe logging for debug messages
|
||||
func (t *TraefikOidc) safeLogDebug(msg string) {
|
||||
if t.logger != nil {
|
||||
|
||||
Reference in New Issue
Block a user