mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
4 Commits
feat/oidcgate
...
v1.0.15
| Author | SHA1 | Date | |
|---|---|---|---|
| 72e2b682bb | |||
| ae4ccaa89d | |||
| 984fd1c08f | |||
| 99bdd23986 |
@@ -411,6 +411,19 @@ namespaced claims, Cognito regions, GitLab self-hosted) live in
|
||||
|
||||
Set `logLevel: debug` to surface detail.
|
||||
|
||||
## Telemetry
|
||||
|
||||
On first plugin instantiation this middleware sends a single anonymous
|
||||
adoption ping — project name, version, timestamp; no identifiers, no
|
||||
request data, no token contents. Fire-and-forget with a 2-second timeout;
|
||||
cannot block plugin load or panic.
|
||||
|
||||
Local source: [`telemetry.go`](./telemetry.go). Disclosure mirrors
|
||||
**[oss-telemetry — Disabling telemetry](https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry)**.
|
||||
|
||||
Quick opt-out: set any of `DO_NOT_TRACK=1`, `OSS_TELEMETRY_DISABLED=1`,
|
||||
or `TRAEFIKOIDC_DISABLE_TELEMETRY=1`.
|
||||
|
||||
## License
|
||||
|
||||
See [LICENSE](LICENSE).
|
||||
|
||||
+2
-2
@@ -71,8 +71,8 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
|
||||
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
|
||||
// Test 3: Rate limiting
|
||||
t.Run("RateLimiting", func(t *testing.T) {
|
||||
// Reset circuit breaker to closed state for this test
|
||||
coordinator.circuitBreaker.mutex.Lock()
|
||||
// Reset circuit breaker to closed state for this test. All fields are
|
||||
// atomic so we don't need any mutex.
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
|
||||
coordinator.circuitBreaker.mutex.Unlock()
|
||||
|
||||
// Temporarily increase circuit breaker threshold to not interfere
|
||||
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
|
||||
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
time.Sleep(config.CleanupInterval * 3)
|
||||
|
||||
// Old sessions should be cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
count := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
count := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have fewer sessions after cleanup
|
||||
if count > 10 {
|
||||
|
||||
+4
-4
@@ -415,8 +415,8 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -457,8 +457,8 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
@@ -89,6 +89,7 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
// - The configured TraefikOidc handler ready to process requests.
|
||||
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
sendTelemetry(pluginVersion)
|
||||
return NewWithContext(ctx, config, next, name)
|
||||
}
|
||||
|
||||
|
||||
+14
-30
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -484,9 +485,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -508,19 +508,13 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate first request processing
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
// Simulate first request processing — single-firing via CAS.
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
// This would normally be called asynchronously
|
||||
go func() {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
// initComplete is closed internally by initializeMetadata
|
||||
}()
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for initialization
|
||||
@@ -556,9 +550,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -580,31 +573,22 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate multiple concurrent "first" requests
|
||||
// Simulate multiple concurrent "first" requests — only one CAS winner
|
||||
// fires the bootstrap path.
|
||||
const numRequests = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
initStarted := 0
|
||||
var initMu sync.Mutex
|
||||
var initStarted int32
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
initMu.Lock()
|
||||
initStarted++
|
||||
initMu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
atomic.AddInt32(&initStarted, 1)
|
||||
// Only one should actually start initialization
|
||||
oidc.initializeMetadata(server.URL)
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -612,8 +596,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// Verify only one initialization was started
|
||||
if initStarted != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
|
||||
if atomic.LoadInt32(&initStarted) != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
|
||||
}
|
||||
|
||||
// The metadata endpoint might be called once or not at all depending on timing
|
||||
|
||||
+28
-28
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com", // Required for initialization check
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -92,8 +92,8 @@ func TestServeHTTP_EventStream(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -175,8 +175,8 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -272,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close this to simulate timeout
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
@@ -307,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -337,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -367,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -740,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -817,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -903,8 +903,8 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: tt.stripAuthCookies,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -987,8 +987,8 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1034,8 +1034,8 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1085,8 +1085,8 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1148,8 +1148,8 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
|
||||
+4
-4
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -2685,10 +2686,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable = true
|
||||
mu.Unlock()
|
||||
|
||||
// Reset the retry timer to allow immediate retry
|
||||
m.metadataRetryMutex.Lock()
|
||||
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
|
||||
m.metadataRetryMutex.Unlock()
|
||||
// Reset the retry timer to allow immediate retry. The field is atomic
|
||||
// now, so no lock is needed.
|
||||
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
|
||||
|
||||
// Second request should trigger recovery attempt
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
|
||||
+16
-14
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
@@ -145,19 +146,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
t.firstRequestReceived = true
|
||||
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
|
||||
// fired on EVERY non-health request forever (even after the boolean
|
||||
// flipped true), which under Yaegi added a per-request serialization
|
||||
// point. CAS gives single-firing semantics with zero steady-state cost.
|
||||
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
|
||||
t.logger.Debug("Starting background tasks on first request")
|
||||
t.startTokenCleanup()
|
||||
|
||||
if !t.metadataRefreshStarted && t.providerURL != "" {
|
||||
t.metadataRefreshStarted = true
|
||||
if t.providerURL != "" &&
|
||||
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
|
||||
// Metadata refresh is handled by singleton resource manager
|
||||
t.startMetadataRefresh(t.providerURL)
|
||||
}
|
||||
}
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
@@ -213,14 +215,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if issuerURL == "" {
|
||||
// Provider metadata initialization failed - try to recover
|
||||
// Retry every 30 seconds to allow automatic recovery when provider comes back online
|
||||
t.metadataRetryMutex.Lock()
|
||||
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
|
||||
if shouldRetry {
|
||||
t.lastMetadataRetryTime = time.Now()
|
||||
}
|
||||
t.metadataRetryMutex.Unlock()
|
||||
// Provider metadata initialization failed - try to recover.
|
||||
// Retry every 30 seconds to allow automatic recovery. Lock-free
|
||||
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
|
||||
// the window, others see shouldRetry=false.
|
||||
nowNano := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
|
||||
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
|
||||
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
|
||||
|
||||
if shouldRetry && t.providerURL != "" {
|
||||
t.logger.Info("Attempting to recover OIDC provider metadata...")
|
||||
|
||||
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate waiting
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
// Create request with canceled context
|
||||
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
@@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
|
||||
+222
-166
@@ -15,17 +15,29 @@ import (
|
||||
// It implements request coalescing, rate limiting, and circuit breaking
|
||||
// specifically for token refresh operations.
|
||||
type RefreshCoordinator struct {
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
// inFlightRefreshes maps tokenHash -> *refreshOperation. sync.Map is used
|
||||
// instead of a plain map + RWMutex so concurrent refreshes do not
|
||||
// serialize on a single global lock. Under Yaegi the previous
|
||||
// refreshMutex.Lock() was held for tens of milliseconds per request due
|
||||
// to interpreter overhead on the work inside the critical section,
|
||||
// causing dozens of goroutines to stack up on it and pin one CPU core.
|
||||
inFlightRefreshes sync.Map
|
||||
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
|
||||
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
|
||||
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
|
||||
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
|
||||
// adds 10-50ms of dispatch overhead, and they were called twice per leader
|
||||
// request (once for recordRefreshAttempt, once for isInCooldown). That
|
||||
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
|
||||
// removed the refreshMutex (same architectural shape, different mutex).
|
||||
sessionRefreshAttempts sync.Map
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
config RefreshCoordinatorConfig
|
||||
wg sync.WaitGroup
|
||||
attemptsMutex sync.RWMutex
|
||||
refreshMutex sync.RWMutex
|
||||
cleanupTimerMu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -84,14 +96,22 @@ type refreshResult struct {
|
||||
fromCache bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session
|
||||
// refreshAttemptTracker tracks refresh attempts for a session. All fields are
|
||||
// accessed via sync/atomic so isInCooldown/recordRefreshAttempt/Success/Failure
|
||||
// can run without holding any per-coordinator lock. Times are UnixNano so they
|
||||
// fit in an int64 and can be read with a single atomic.LoadInt64.
|
||||
//
|
||||
// cooldownEndNano == 0 means "not in cooldown". This sentinel replaces the
|
||||
// inCooldown bool that the previous implementation kept under attemptsMutex —
|
||||
// under Yaegi any per-request global mutex turns into a serializing bottleneck
|
||||
// (the v1.0.14 refreshMutex -> sync.Map fix removed only one such bottleneck;
|
||||
// attemptsMutex was the next one in the queue).
|
||||
type refreshAttemptTracker struct {
|
||||
lastAttemptTime time.Time
|
||||
windowStartTime time.Time
|
||||
cooldownEndTime time.Time
|
||||
attempts int32
|
||||
consecutiveFailures int32
|
||||
inCooldown bool
|
||||
lastAttemptNano int64 // atomic, UnixNano of last attempt
|
||||
windowStartNano int64 // atomic, UnixNano of attempt-window start
|
||||
cooldownEndNano int64 // atomic, UnixNano; 0 = not in cooldown
|
||||
attempts int32 // atomic
|
||||
consecutiveFailures int32 // atomic
|
||||
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics
|
||||
@@ -106,14 +126,18 @@ type RefreshMetrics struct {
|
||||
currentInFlightRefreshes int32
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
|
||||
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
|
||||
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
|
||||
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
|
||||
// interpreter dispatch per call, which compounded with attemptsMutex to keep
|
||||
// the pod's single CPU core saturated.
|
||||
type RefreshCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureNano int64 // atomic, UnixNano of most recent failure
|
||||
lastSuccessNano int64 // atomic, UnixNano of most recent success
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
state int32
|
||||
failures int32
|
||||
state int32 // atomic: 0=closed, 1=open, 2=half-open
|
||||
failures int32 // atomic
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
|
||||
@@ -130,13 +154,13 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
}
|
||||
|
||||
rc := &RefreshCoordinator{
|
||||
inFlightRefreshes: make(map[string]*refreshOperation),
|
||||
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
|
||||
// their zero values are ready to use.
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
@@ -227,13 +251,28 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
tokenHash string,
|
||||
refreshToken string,
|
||||
) (*refreshOperation, bool, error) {
|
||||
rc.refreshMutex.Lock()
|
||||
defer rc.refreshMutex.Unlock()
|
||||
// Speculatively construct the operation we WOULD register if we win the
|
||||
// race. Allocating here keeps the LoadOrStore call below atomic and
|
||||
// avoids any global lock — under Yaegi the previous map+RWMutex design
|
||||
// held the write lock long enough (tens of ms per call) that concurrent
|
||||
// refreshes on the same coordinator serialized into a queue that grew
|
||||
// without bound. See struct comment on inFlightRefreshes.
|
||||
candidate := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
|
||||
// Check for existing operation while holding the lock
|
||||
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
|
||||
if existing, loaded := rc.inFlightRefreshes.LoadOrStore(tokenHash, candidate); loaded {
|
||||
existingOp, ok := existing.(*refreshOperation)
|
||||
if !ok {
|
||||
// Defensive: anything stored here is always *refreshOperation, but
|
||||
// keep the typed assert so a programming error elsewhere doesn't
|
||||
// surface as a confusing panic in an interpreter frame.
|
||||
return nil, false, fmt.Errorf("inFlightRefreshes corrupt: unexpected type %T", existing)
|
||||
}
|
||||
if existingOp.refreshToken == refreshToken {
|
||||
// Join existing operation
|
||||
atomic.AddInt32(&existingOp.waiterCount, 1)
|
||||
return existingOp, false, nil
|
||||
}
|
||||
@@ -241,41 +280,60 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
return nil, false, fmt.Errorf("refresh token mismatch")
|
||||
}
|
||||
|
||||
// No existing operation - check if we can create a new one
|
||||
// All checks happen while holding the lock to prevent races
|
||||
// We won the race and registered `candidate`. Apply gates now. If any
|
||||
// gate fails we must remove our entry from the map and signal failure
|
||||
// to any joiners that snuck in between LoadOrStore and now.
|
||||
if err := rc.applyLeaderGates(sessionID); err != nil {
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// Check and record refresh attempt for rate limiting
|
||||
// Reserve concurrent slot via CAS — without the old global lock we can
|
||||
// no longer rely on mutex-mediated check-then-increment. If we lose the
|
||||
// CAS race we retry; if the limit has since been reached we back out.
|
||||
for {
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
err := fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
if atomic.CompareAndSwapInt32(&rc.metrics.currentInFlightRefreshes, current, current+1) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return candidate, true, nil
|
||||
}
|
||||
|
||||
// applyLeaderGates runs the rate-limit, cooldown, and memory-pressure checks
|
||||
// that previously ran under the global refreshMutex. Only the leader (the
|
||||
// goroutine that just registered the operation) runs them; joiners share the
|
||||
// leader's outcome via operation.done.
|
||||
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
if rc.isInCooldown(sessionID) {
|
||||
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
|
||||
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
|
||||
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
|
||||
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
|
||||
return fmt.Errorf("system under memory pressure, refresh denied")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check and reserve concurrent refresh slot atomically
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
}
|
||||
|
||||
// Reserve the slot - we're still holding the lock so this is safe
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
|
||||
// Create and register new operation
|
||||
operation := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
rc.inFlightRefreshes[tokenHash] = operation
|
||||
|
||||
return operation, true, nil
|
||||
// failCandidate removes the leader's just-registered operation from the
|
||||
// in-flight map and signals the error to any joiners by recording the result
|
||||
// and closing the done channel. This keeps the (nil, false, err) return path
|
||||
// equivalent to the pre-sync.Map version: callers see the error directly,
|
||||
// joiners see it via operation.done.
|
||||
func (rc *RefreshCoordinator) failCandidate(tokenHash string, op *refreshOperation, err error) {
|
||||
rc.inFlightRefreshes.Delete(tokenHash)
|
||||
op.mutex.Lock()
|
||||
op.result = &refreshResult{err: err}
|
||||
op.mutex.Unlock()
|
||||
close(op.done)
|
||||
}
|
||||
|
||||
// executeRefreshAsync performs the actual refresh operation asynchronously
|
||||
@@ -367,100 +425,108 @@ func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
|
||||
|
||||
// performCleanup removes the operation from the in-flight map.
|
||||
// Idempotent: only decrements the in-flight counter if an entry was actually
|
||||
// removed. This guards against any future path accidentally calling cleanup
|
||||
// twice for the same tokenHash (which would corrupt the refresh budget).
|
||||
// removed. LoadAndDelete is atomic so any concurrent failCandidate or repeat
|
||||
// cleanup call will see exactly one removal — the budget cannot be corrupted
|
||||
// by double-decrement.
|
||||
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
|
||||
rc.refreshMutex.Lock()
|
||||
_, existed := rc.inFlightRefreshes[tokenHash]
|
||||
if existed {
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
}
|
||||
rc.refreshMutex.Unlock()
|
||||
if existed {
|
||||
if _, existed := rc.inFlightRefreshes.LoadAndDelete(tokenHash); existed {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown after recording an attempt
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
|
||||
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
|
||||
// concurrent first-touch races: at most one tracker per sessionID survives.
|
||||
//
|
||||
// trackerFromMapValue centralizes the type assertion so the lint-mandated
|
||||
// two-value form lives in one place; the stored type is always
|
||||
// *refreshAttemptTracker by construction.
|
||||
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
|
||||
t, _ := v.(*refreshAttemptTracker)
|
||||
return t
|
||||
}
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
return trackerFromMapValue(v)
|
||||
}
|
||||
fresh := &refreshAttemptTracker{
|
||||
windowStartNano: time.Now().UnixNano(),
|
||||
}
|
||||
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
|
||||
return trackerFromMapValue(actual)
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown. Lock-free read with a
|
||||
// best-effort cooldown-reset CAS on the cooldownEndNano sentinel. If the
|
||||
// reset races with another goroutine we accept the loser's view (the winner's
|
||||
// reset still happens). The attempt-window expiry and limit-exceeded paths
|
||||
// are write-mostly but use atomic.StoreInt64/AddInt32 — never a held lock.
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return false // No tracker means first attempt, not in cooldown
|
||||
}
|
||||
|
||||
tracker := trackerFromMapValue(v)
|
||||
now := time.Now()
|
||||
nowNano := now.UnixNano()
|
||||
|
||||
// Check if already in cooldown
|
||||
if tracker.inCooldown {
|
||||
if now.After(tracker.cooldownEndTime) {
|
||||
// Cooldown expired, reset tracker
|
||||
tracker.inCooldown = false
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.consecutiveFailures = 0
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
// Already in cooldown?
|
||||
if cooldownEnd := atomic.LoadInt64(&tracker.cooldownEndNano); cooldownEnd != 0 {
|
||||
if nowNano <= cooldownEnd {
|
||||
return true // still in cooldown
|
||||
}
|
||||
// Cooldown expired. Best-effort reset (a concurrent caller may also
|
||||
// reset; the result is equivalent — fresh window + one recorded
|
||||
// attempt — so the CAS race is benign).
|
||||
if atomic.CompareAndSwapInt64(&tracker.cooldownEndNano, cooldownEnd, 0) {
|
||||
atomic.StoreInt32(&tracker.attempts, 1)
|
||||
atomic.StoreInt32(&tracker.consecutiveFailures, 0)
|
||||
atomic.StoreInt64(&tracker.windowStartNano, nowNano)
|
||||
}
|
||||
return true // Still in cooldown
|
||||
}
|
||||
|
||||
// Check if window expired
|
||||
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
|
||||
// Reset window
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if just exceeded attempt limit
|
||||
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
|
||||
// Enter cooldown now
|
||||
tracker.inCooldown = true
|
||||
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, tracker.attempts)
|
||||
// Window expired?
|
||||
if windowStart := atomic.LoadInt64(&tracker.windowStartNano); time.Duration(nowNano-windowStart) > rc.config.RefreshAttemptWindow {
|
||||
atomic.StoreInt32(&tracker.attempts, 1)
|
||||
atomic.StoreInt64(&tracker.windowStartNano, nowNano)
|
||||
return false
|
||||
}
|
||||
|
||||
// Just exceeded attempt limit?
|
||||
if int(atomic.LoadInt32(&tracker.attempts)) >= rc.config.MaxRefreshAttempts {
|
||||
end := now.Add(rc.config.RefreshCooldownPeriod).UnixNano()
|
||||
// Only one CAS winner publishes the cooldown end + logs.
|
||||
if atomic.CompareAndSwapInt64(&tracker.cooldownEndNano, 0, end) {
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, atomic.LoadInt32(&tracker.attempts))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free:
|
||||
// LoadOrStore for the tracker, atomic counters/timestamps for fields.
|
||||
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
tracker = &refreshAttemptTracker{
|
||||
windowStartTime: time.Now(),
|
||||
}
|
||||
rc.sessionRefreshAttempts[sessionID] = tracker
|
||||
}
|
||||
|
||||
tracker := rc.getOrCreateTracker(sessionID)
|
||||
atomic.AddInt32(&tracker.attempts, 1)
|
||||
tracker.lastAttemptTime = time.Now()
|
||||
atomic.StoreInt64(&tracker.lastAttemptNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// recordRefreshSuccess records a successful refresh
|
||||
// recordRefreshSuccess records a successful refresh. Lock-free.
|
||||
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
tracker.consecutiveFailures = 0
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
atomic.StoreInt32(&trackerFromMapValue(v).consecutiveFailures, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// recordRefreshFailure records a failed refresh
|
||||
// recordRefreshFailure records a failed refresh. Lock-free.
|
||||
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
atomic.AddInt32(&tracker.consecutiveFailures, 1)
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
atomic.AddInt32(&trackerFromMapValue(v).consecutiveFailures, 1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,20 +578,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleEntries removes outdated tracking entries
|
||||
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
|
||||
// via sync.Map.Range; safe to race with concurrent reads/writes.
|
||||
func (rc *RefreshCoordinator) cleanupStaleEntries() {
|
||||
now := time.Now()
|
||||
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
// Clean up old session trackers
|
||||
for sessionID, tracker := range rc.sessionRefreshAttempts {
|
||||
// Remove trackers that haven't been used recently
|
||||
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
|
||||
delete(rc.sessionRefreshAttempts, sessionID)
|
||||
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
|
||||
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
|
||||
tracker := trackerFromMapValue(value)
|
||||
if tracker == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if atomic.LoadInt64(&tracker.lastAttemptNano) < cutoff {
|
||||
// Compare-and-delete to avoid evicting a tracker that was just
|
||||
// re-used by a concurrent caller. We compare by pointer identity.
|
||||
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetMetrics returns current coordinator metrics
|
||||
@@ -558,63 +626,51 @@ func (rc *RefreshCoordinator) Shutdown() {
|
||||
rc.wg.Wait()
|
||||
}
|
||||
|
||||
// AllowRequest checks if the circuit breaker allows a request
|
||||
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
switch state {
|
||||
case 0: // Closed
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0: // closed
|
||||
return true
|
||||
case 1: // Open
|
||||
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
|
||||
// Try to transition to half-open
|
||||
case 1: // open
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
|
||||
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
|
||||
// Transition to half-open; first CAS winner gets the probe.
|
||||
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case 2: // Half-open
|
||||
case 2: // half-open
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
// RecordSuccess records a successful operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == 2 { // Half-open
|
||||
// Close the circuit
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 2: // half-open -> close
|
||||
atomic.StoreInt32(&cb.state, 0)
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
} else if state == 0 { // Closed
|
||||
// Reset failure count on success
|
||||
case 0: // closed
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
}
|
||||
cb.lastSuccessTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
// RecordFailure records a failed operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
failures := atomic.AddInt32(&cb.failures, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
if state == 0 && int(failures) >= cb.config.MaxFailures {
|
||||
// Open the circuit
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
} else if state == 2 {
|
||||
// Half-open failed, return to open
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0:
|
||||
if int(failures) >= cb.config.MaxFailures {
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
case 2:
|
||||
// Half-open probe failed -> back to open.
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
}
|
||||
|
||||
+16
-15
@@ -365,10 +365,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cleanup is working
|
||||
coordinator.attemptsMutex.RLock()
|
||||
sessionCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
// Verify cleanup is working. sync.Map has no Len(); count via Range.
|
||||
sessionCount := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
sessionCount++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have cleaned up old sessions (only recent ones remain)
|
||||
if sessionCount > numWorkers*2 {
|
||||
@@ -650,24 +652,23 @@ func TestCleanupRoutine(t *testing.T) {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
||||
}
|
||||
|
||||
// Verify sessions exist
|
||||
coordinator.attemptsMutex.RLock()
|
||||
initialCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
countSessions := func() int {
|
||||
n := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
n++
|
||||
return true
|
||||
})
|
||||
return n
|
||||
}
|
||||
|
||||
if initialCount != 5 {
|
||||
if initialCount := countSessions(); initialCount != 5 {
|
||||
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (2x window + cleanup interval)
|
||||
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
||||
|
||||
// Verify sessions were cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
finalCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
if finalCount := countSessions(); finalCount != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
+142
@@ -0,0 +1,142 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pluginVersion is bumped manually on each release. Keep in sync with the
|
||||
// most recent git tag (see `git tag --sort=-v:refname | head -1`).
|
||||
const pluginVersion = "1.0.11"
|
||||
|
||||
const (
|
||||
telemetryProject = "traefikoidc"
|
||||
telemetryTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// telemetryEndpoint is intentionally a var rather than a const so the test
|
||||
// suite in this package can retarget it at an httptest server. Production
|
||||
// code never mutates it.
|
||||
var telemetryEndpoint = "https://oss.raczylo.com/v1/ping"
|
||||
|
||||
// telemetryOnce guarantees a single anonymous "plugin loaded" ping per
|
||||
// process lifetime. Traefik can instantiate a middleware many times per
|
||||
// process (one per route using the plugin); the sync.Once gate keeps the
|
||||
// fire-and-forget call from amplifying into many pings.
|
||||
//
|
||||
// Reset in tests via `telemetryOnce = sync.Once{}`.
|
||||
var telemetryOnce sync.Once
|
||||
|
||||
// telemetryInflight tracks any background goroutine started by sendTelemetry.
|
||||
// Tests Wait on it to drain in-flight goroutines before mutating package
|
||||
// state. Production code never calls Wait — the goroutine is fire-and-forget.
|
||||
var telemetryInflight sync.WaitGroup
|
||||
|
||||
// sendTelemetry fires one anonymous usage ping in the background. It is
|
||||
// failproof by contract:
|
||||
//
|
||||
// - never blocks the caller
|
||||
// - never panics (the goroutine recovers internally)
|
||||
// - never returns errors
|
||||
// - silently dropped on invalid input, env-driven opt-out, or network failure
|
||||
//
|
||||
// Opt-out is honored via any of:
|
||||
//
|
||||
// - DO_NOT_TRACK=1
|
||||
// - OSS_TELEMETRY_DISABLED=1
|
||||
// - TRAEFIKOIDC_DISABLE_TELEMETRY=1
|
||||
//
|
||||
// Yaegi note: this file deliberately avoids generics (atomic.Pointer[T]) and
|
||||
// range-over-int (Go 1.22) so it interprets under any reasonably recent
|
||||
// Traefik yaegi runtime.
|
||||
func sendTelemetry(version string) {
|
||||
telemetryOnce.Do(func() {
|
||||
if telemetryDisabledByEnv() {
|
||||
return
|
||||
}
|
||||
if !validTelemetryVersion(version) {
|
||||
return
|
||||
}
|
||||
telemetryInflight.Add(1)
|
||||
go func() {
|
||||
defer telemetryInflight.Done()
|
||||
defer func() { _ = recover() }()
|
||||
doTelemetryPost(version)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func telemetryDisabledByEnv() bool {
|
||||
keys := []string{
|
||||
"DO_NOT_TRACK",
|
||||
"OSS_TELEMETRY_DISABLED",
|
||||
"TRAEFIKOIDC_DISABLE_TELEMETRY",
|
||||
}
|
||||
for _, k := range keys {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(k)))
|
||||
if v == "1" || v == "true" || v == "yes" || v == "on" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// validTelemetryVersion mirrors the server-side regex ^[A-Za-z0-9.+_-]{1,32}$
|
||||
// using a byte loop. No allocation, no regexp dependency.
|
||||
//
|
||||
// Yaegi note: written as an `||` chain rather than `switch{case A,B,C:}` —
|
||||
// some yaegi releases mis-evaluate comma-separated case expressions in
|
||||
// switch-true blocks, returning false for all inputs.
|
||||
func validTelemetryVersion(v string) bool {
|
||||
if len(v) == 0 || len(v) > 32 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(v); i++ {
|
||||
c := v[i]
|
||||
ok := (c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '.' || c == '+' || c == '_' || c == '-'
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// doTelemetryPost builds the JSON body manually. The project name is a
|
||||
// constant and the version is pre-validated against an ASCII-only allowlist,
|
||||
// so direct concatenation needs no JSON escaping.
|
||||
func doTelemetryPost(version string) {
|
||||
body := make([]byte, 0, 96)
|
||||
body = append(body, `{"project":"`...)
|
||||
body = append(body, telemetryProject...)
|
||||
body = append(body, `","version":"`...)
|
||||
body = append(body, version...)
|
||||
body = append(body, `","ts":`...)
|
||||
body = strconv.AppendInt(body, time.Now().Unix(), 10)
|
||||
body = append(body, '}')
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), telemetryTimeout)
|
||||
defer cancel()
|
||||
|
||||
url := telemetryEndpoint
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: telemetryTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// resetTelemetryState restores package-level mutable state so tests do not
|
||||
// contaminate one another. The cleanup waits for any in-flight ping goroutine
|
||||
// to finish before restoring telemetryEndpoint — without that drain step the
|
||||
// goroutine and the cleanup would race on the var.
|
||||
func resetTelemetryState(t *testing.T) {
|
||||
t.Helper()
|
||||
telemetryOnce = sync.Once{}
|
||||
prev := telemetryEndpoint
|
||||
t.Cleanup(func() {
|
||||
telemetryInflight.Wait()
|
||||
telemetryEndpoint = prev
|
||||
telemetryOnce = sync.Once{}
|
||||
})
|
||||
}
|
||||
|
||||
func newTelemetryServer(t *testing.T, status int) (hits *int32, lastBody func() string) {
|
||||
t.Helper()
|
||||
var counter int32
|
||||
var mu sync.Mutex
|
||||
var body string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&counter, 1)
|
||||
b, _ := io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
mu.Lock()
|
||||
body = string(b)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(status)
|
||||
}))
|
||||
telemetryEndpoint = srv.URL
|
||||
t.Cleanup(srv.Close)
|
||||
return &counter, func() string {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return body
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidTelemetryVersion(t *testing.T) {
|
||||
good := []string{"1.2.3", "1.4.0-beta1", "2.0", "v1.0.0", "1.0.0+meta", "dev"}
|
||||
for _, v := range good {
|
||||
if !validTelemetryVersion(v) {
|
||||
t.Errorf("validTelemetryVersion(%q) = false, want true", v)
|
||||
}
|
||||
}
|
||||
bad := []string{"", "has space", "semi;colon", strings.Repeat("1", 33)}
|
||||
for _, v := range bad {
|
||||
if validTelemetryVersion(v) {
|
||||
t.Errorf("validTelemetryVersion(%q) = true, want false", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTelemetryDisabledByEnv(t *testing.T) {
|
||||
for _, k := range []string{"DO_NOT_TRACK", "OSS_TELEMETRY_DISABLED", "TRAEFIKOIDC_DISABLE_TELEMETRY"} {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv(k, "1")
|
||||
if !telemetryDisabledByEnv() {
|
||||
t.Fatalf("%s=1 should disable", k)
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("falsy_values_do_not_disable", func(t *testing.T) {
|
||||
t.Setenv("DO_NOT_TRACK", "0")
|
||||
t.Setenv("OSS_TELEMETRY_DISABLED", "false")
|
||||
t.Setenv("TRAEFIKOIDC_DISABLE_TELEMETRY", "no")
|
||||
if telemetryDisabledByEnv() {
|
||||
t.Fatal("falsy env values should not disable")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendTelemetry_FiresOnceAcrossManyCalls(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, lastBody := newTelemetryServer(t, http.StatusNoContent)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
sendTelemetry("1.2.3")
|
||||
}
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 1 {
|
||||
t.Fatalf("expected exactly 1 hit, got %d", got)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Project string `json:"project"`
|
||||
Version string `json:"version"`
|
||||
Ts int64 `json:"ts"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(lastBody()), &payload); err != nil {
|
||||
t.Fatalf("server received non-JSON body: %q (err: %v)", lastBody(), err)
|
||||
}
|
||||
if payload.Project != "traefikoidc" || payload.Version != "1.2.3" || payload.Ts <= 0 {
|
||||
t.Fatalf("unexpected payload: %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_RespectsDisableEnv(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusNoContent)
|
||||
t.Setenv("DO_NOT_TRACK", "1")
|
||||
|
||||
sendTelemetry("1.2.3")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 0 {
|
||||
t.Fatalf("DO_NOT_TRACK should suppress; got %d hits", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_DropsInvalidVersion(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusNoContent)
|
||||
|
||||
sendTelemetry("has space")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 0 {
|
||||
t.Fatalf("invalid version should suppress; got %d hits", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_DoesNotBlock(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
// Hanging server proves the caller is never blocked. The 2s context
|
||||
// timeout in doTelemetryPost ensures the goroutine eventually exits;
|
||||
// resetTelemetryState's cleanup waits for that drain before restoring
|
||||
// telemetryEndpoint so there is no race with this test's mutation.
|
||||
hung := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
time.Sleep(5 * time.Second)
|
||||
}))
|
||||
t.Cleanup(hung.Close)
|
||||
telemetryEndpoint = hung.URL
|
||||
|
||||
start := time.Now()
|
||||
sendTelemetry("1.2.3")
|
||||
if elapsed := time.Since(start); elapsed > 50*time.Millisecond {
|
||||
t.Fatalf("sendTelemetry blocked for %v, expected near-instant return", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTelemetry_SurvivesServerError(t *testing.T) {
|
||||
resetTelemetryState(t)
|
||||
hits, _ := newTelemetryServer(t, http.StatusInternalServerError)
|
||||
|
||||
sendTelemetry("1.2.3")
|
||||
telemetryInflight.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(hits); got != 1 {
|
||||
t.Fatalf("request should still reach server even on 500; got %d hits", got)
|
||||
}
|
||||
}
|
||||
@@ -65,7 +65,19 @@ type ProviderMetadata struct {
|
||||
// the complete authentication flow. It's designed to work seamlessly with Traefik's
|
||||
// plugin system and provides flexible configuration options.
|
||||
type TraefikOidc struct {
|
||||
lastMetadataRetryTime time.Time
|
||||
// lastMetadataRetryNano is the UnixNano timestamp of the last metadata
|
||||
// recovery attempt. Stored atomically so the hot ServeHTTP path can
|
||||
// throttle retries without acquiring metadataRetryMutex on every request.
|
||||
lastMetadataRetryNano int64
|
||||
// firstRequestStarted is 0 until the very first non-health request fires
|
||||
// the background-task bootstrap; then it flips to 1 via CAS. Replaces the
|
||||
// firstRequestMutex + firstRequestReceived combo which previously took
|
||||
// a write lock on every non-health request forever.
|
||||
firstRequestStarted int32
|
||||
// metadataRefreshStartedAtomic is the CAS-only variant of the old
|
||||
// metadataRefreshStarted bool. Both flags live under the same atomic so
|
||||
// concurrent first-request goroutines race exactly once.
|
||||
metadataRefreshStartedAtomic int32
|
||||
jwkCache JWKCacheInterface
|
||||
jwtVerifier JWTVerifier
|
||||
ctx context.Context
|
||||
@@ -130,17 +142,13 @@ type TraefikOidc struct {
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
enableFrontchannelLogout bool
|
||||
firstRequestReceived bool
|
||||
requireTokenIntrospection bool
|
||||
metadataRefreshStarted bool
|
||||
allowPrivateIPAddresses bool
|
||||
disableReplayDetection bool
|
||||
allowOpaqueTokens bool
|
||||
|
||||
Reference in New Issue
Block a user