release 0.7.2 (#66)

* Remove trailing / from metadata provider.

* Resolves issue #67
    - Before: 100 concurrent requests → 300+ refresh attempts → OOM
    - After: 100 concurrent requests → 1 refresh attempt → Stable memory

Added the following changes:
    - Introduced a refresh coordinator to manage concurrent refresh requests
    - Implemented a test to simulate high concurrency and verify memory stability

* Issue #67 fixed.
This commit is contained in:
2025-09-25 12:52:53 +01:00
committed by GitHub
parent 1b49e133da
commit 1e4142a7fb
9 changed files with 2001 additions and 1 deletions
+541
View File
@@ -0,0 +1,541 @@
package traefikoidc
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestIssue67_InfiniteRefreshLoop reproduces and verifies the fix for issue #67,
// where concurrent requests carrying the same expired token each started their
// own refresh, leading to a refresh storm and OOM.
//
// It sends numConcurrentRequests goroutines through a single RefreshCoordinator
// against a mock token endpoint whose first responses are slow failures, then
// checks that:
//   - every caller returns before the timeout,
//   - heap growth stays small,
//   - the token endpoint receives far fewer calls than there were callers,
//   - the endpoint never serves more calls at once than MaxConcurrentRefreshes,
//   - the coordinator reports deduplicated requests.
func TestIssue67_InfiniteRefreshLoop(t *testing.T) {
	// Baseline heap usage, sampled after a forced GC.
	runtime.GC()
	var startMem runtime.MemStats
	runtime.ReadMemStats(&startMem)

	// Counters maintained by the mock authorization server.
	var refreshAttempts int32     // total calls to /token
	var concurrentRefreshes int32 // calls to /token being served right now
	var maxConcurrent int32       // high-water mark of concurrentRefreshes

	// serverURL is only known once the server is listening; it is assigned
	// right after httptest.NewServer returns.
	var serverURL string
	authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/token":
			current := atomic.AddInt32(&concurrentRefreshes, 1)
			defer atomic.AddInt32(&concurrentRefreshes, -1)

			// Lock-free update of the high-water mark.
			for {
				peak := atomic.LoadInt32(&maxConcurrent)
				if current <= peak || atomic.CompareAndSwapInt32(&maxConcurrent, peak, current) {
					break
				}
			}

			attempts := atomic.AddInt32(&refreshAttempts, 1)

			// Simulate a slow/failing token endpoint (as in the issue): the
			// first few attempts fail, later ones succeed.
			if attempts < 5 {
				time.Sleep(100 * time.Millisecond)
				w.WriteHeader(http.StatusServiceUnavailable)
				fmt.Fprint(w, `{"error": "temporarily_unavailable"}`)
				return
			}
			time.Sleep(50 * time.Millisecond)
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprint(w, `{
	"access_token": "new_access_token",
	"refresh_token": "new_refresh_token",
	"id_token": "new_id_token",
	"expires_in": 3600,
	"token_type": "Bearer"
}`)
		case "/.well-known/openid-configuration":
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprintf(w, `{
	"issuer": "%s",
	"authorization_endpoint": "%s/authorize",
	"token_endpoint": "%s/token",
	"jwks_uri": "%s/keys",
	"response_types_supported": ["code"],
	"subject_types_supported": ["public"],
	"id_token_signing_alg_values_supported": ["RS256"],
	"scopes_supported": ["openid", "profile", "email"],
	"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
	"claims_supported": ["sub", "name", "email"]
}`, serverURL, serverURL, serverURL, serverURL)
		case "/keys":
			w.Header().Set("Content-Type", "application/json")
			fmt.Fprint(w, `{
	"keys": [{
		"kty": "RSA",
		"use": "sig",
		"kid": "test-key",
		"n": "test",
		"e": "AQAB"
	}]
}`)
		}
	}))
	defer authServer.Close()
	serverURL = authServer.URL

	// Coordinator under test, with tight limits so violations show up quickly.
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.MaxRefreshAttempts = 3
	config.RefreshAttemptWindow = 1 * time.Second
	config.MaxConcurrentRefreshes = 2
	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Every simulated request shares this one expired session.
	expiredSession := &MockExpiredSession{
		refreshToken: "test_refresh_token",
		sessionID:    "test_session",
		isExpired:    true,
	}

	// Simulate multiple concurrent requests (as reported in the issue).
	numConcurrentRequests := 50
	var wg sync.WaitGroup
	wg.Add(numConcurrentRequests)

	// Outcome bookkeeping shared by the request goroutines.
	var successCount int32
	var errorCount int32
	refreshErrs := make([]error, 0, numConcurrentRequests)
	var errMu sync.Mutex

	startTime := time.Now()
	timeout := 5 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	for i := 0; i < numConcurrentRequests; i++ {
		go func(reqID int) {
			defer wg.Done()

			// refreshFunc calls the mock token endpoint. The request is bound
			// to ctx so a hung endpoint cannot outlive the test deadline.
			refreshFunc := func() (*TokenResponse, error) {
				req, err := http.NewRequestWithContext(ctx, http.MethodPost, serverURL+"/token", nil)
				if err != nil {
					return nil, err
				}
				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

				resp, err := http.DefaultClient.Do(req)
				if err != nil {
					return nil, err
				}
				defer resp.Body.Close()

				if resp.StatusCode != http.StatusOK {
					return nil, fmt.Errorf("token refresh failed: %d", resp.StatusCode)
				}
				return &TokenResponse{
					AccessToken:  fmt.Sprintf("new_access_%d", reqID),
					RefreshToken: "new_refresh",
					IDToken:      "new_id",
					ExpiresIn:    3600,
				}, nil
			}

			// All goroutines go through the coordinator with the same session
			// and refresh token, so they should coalesce.
			result, err := coordinator.CoordinateRefresh(
				ctx,
				expiredSession.sessionID,
				expiredSession.refreshToken,
				refreshFunc,
			)
			switch {
			case err != nil:
				atomic.AddInt32(&errorCount, 1)
				errMu.Lock()
				refreshErrs = append(refreshErrs, err)
				errMu.Unlock()
			case result != nil:
				atomic.AddInt32(&successCount, 1)
			}
		}(i)
	}

	// Wait for completion or timeout.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		// Completed normally.
	case <-ctx.Done():
		t.Fatal("Test timed out - possible infinite loop detected!")
	}

	elapsed := time.Since(startTime)
	if elapsed > timeout {
		t.Fatalf("Requests took too long: %v (possible infinite loop)", elapsed)
	}

	// Heap usage after the run, again following a forced GC.
	runtime.GC()
	var endMem runtime.MemStats
	runtime.ReadMemStats(&endMem)

	// HeapAlloc is unsigned: only subtract when the heap actually grew,
	// otherwise the difference would wrap around.
	var memGrowthMB float64
	if endMem.HeapAlloc >= startMem.HeapAlloc {
		memGrowthMB = float64(endMem.HeapAlloc-startMem.HeapAlloc) / (1024 * 1024)
	}
	t.Logf("Memory stats: start=%d bytes, end=%d bytes, growth=%.2f MB",
		startMem.HeapAlloc, endMem.HeapAlloc, memGrowthMB)

	// Memory should not grow excessively (issue reported OOM at 2GB).
	if memGrowthMB > 100 {
		t.Errorf("Excessive memory growth: %.2f MB (possible memory leak)", memGrowthMB)
	}

	// Snapshot the shared counters.
	actualRefreshAttempts := atomic.LoadInt32(&refreshAttempts)
	observedMaxConcurrent := atomic.LoadInt32(&maxConcurrent)
	t.Logf("Total refresh attempts to server: %d", actualRefreshAttempts)
	t.Logf("Max concurrent refreshes: %d", observedMaxConcurrent)
	t.Logf("Successful refreshes: %d", atomic.LoadInt32(&successCount))
	t.Logf("Failed refreshes: %d", atomic.LoadInt32(&errorCount))

	// Surface one representative error so failures can be diagnosed.
	errMu.Lock()
	if len(refreshErrs) > 0 {
		t.Logf("First refresh error (of %d): %v", len(refreshErrs), refreshErrs[0])
	}
	errMu.Unlock()

	// With deduplication, refresh attempts should be much less than concurrent requests.
	if actualRefreshAttempts > int32(numConcurrentRequests/2) {
		t.Errorf("Too many refresh attempts (%d), deduplication not working properly",
			actualRefreshAttempts)
	}

	// Max concurrent should respect our limit.
	if observedMaxConcurrent > int32(config.MaxConcurrentRefreshes) {
		t.Errorf("Max concurrent refreshes (%d) exceeded configured limit (%d)",
			observedMaxConcurrent, config.MaxConcurrentRefreshes)
	}

	// Check coordinator metrics. A missing or mistyped metric is a failure
	// rather than a silently skipped check.
	metrics := coordinator.GetMetrics()
	t.Logf("Coordinator metrics: %+v", metrics)
	deduped, ok := metrics["deduplicated_requests"].(int64)
	switch {
	case !ok:
		t.Errorf("deduplicated_requests metric missing or not an int64: %v",
			metrics["deduplicated_requests"])
	case deduped == 0:
		t.Error("No requests were deduplicated - deduplication not working")
	default:
		t.Logf("Deduplicated requests: %d", deduped)
	}
}
// TestIssue67_WithoutCoordinator demonstrates the issue without the fix:
// every simulated request runs its own retry loop with no deduplication.
// It only logs what it observes and makes no assertions.
//
// WARNING: This test may consume significant memory - skip in CI. It is
// skipped in short mode and unless the test binary runs verbosely (-v).
func TestIssue67_WithoutCoordinator(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping memory-intensive test in short mode")
	}
	// Only run this test with explicit flag to demonstrate the issue.
	if !testing.Verbose() {
		t.Skip("Skipping demonstration of issue without fix (run with -v to see)")
	}

	// Baseline heap usage, sampled after a forced GC.
	runtime.GC()
	var startMem runtime.MemStats
	runtime.ReadMemStats(&startMem)

	var refreshAttempts int32   // total simulated refresh calls
	var maxConcurrent int32     // high-water mark of currentConcurrent
	var currentConcurrent int32 // simulated refreshes running right now

	numRequests := 100
	var wg sync.WaitGroup
	wg.Add(numRequests)

	// Use a context with short timeout to prevent actual OOM.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < numRequests; i++ {
		go func(id int) {
			defer wg.Done()

			// Retry logic without deduplication (the bug): each request makes
			// three attempts, as if the first two had failed.
			for attempt := 0; attempt < 3; attempt++ {
				select {
				case <-ctx.Done():
					return
				default:
				}

				current := atomic.AddInt32(&currentConcurrent, 1)

				// Lock-free update of the high-water mark.
				for {
					peak := atomic.LoadInt32(&maxConcurrent)
					if current <= peak || atomic.CompareAndSwapInt32(&maxConcurrent, peak, current) {
						break
					}
				}
				atomic.AddInt32(&refreshAttempts, 1)

				// Linearly growing delay standing in for retry backoff.
				time.Sleep(time.Duration(attempt*100) * time.Millisecond)

				// Allocate memory to simulate token processing (10KB per attempt).
				_ = make([]byte, 1024*10)

				atomic.AddInt32(&currentConcurrent, -1)
			}
		}(i)
	}

	// Wait with timeout.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		// Completed.
	case <-ctx.Done():
		// Timed out (expected in problematic scenario).
	}

	// Heap usage after the run, again following a forced GC.
	runtime.GC()
	var endMem runtime.MemStats
	runtime.ReadMemStats(&endMem)

	// HeapAlloc is unsigned: only subtract when the heap actually grew,
	// otherwise the difference would wrap around to a huge value.
	var memGrowthMB float64
	if endMem.HeapAlloc >= startMem.HeapAlloc {
		memGrowthMB = float64(endMem.HeapAlloc-startMem.HeapAlloc) / (1024 * 1024)
	}

	// Workers may still be running if the wait above timed out, so the
	// counters are read atomically.
	attempts := atomic.LoadInt32(&refreshAttempts)
	peakConcurrent := atomic.LoadInt32(&maxConcurrent)

	t.Logf("WITHOUT COORDINATOR:")
	t.Logf(" Refresh attempts: %d", attempts)
	t.Logf(" Max concurrent: %d", peakConcurrent)
	t.Logf(" Memory growth: %.2f MB", memGrowthMB)

	// This demonstrates the issue - high concurrency and many attempts.
	if attempts < int32(numRequests*2) {
		t.Logf("Note: Without coordinator, saw %d refresh attempts for %d requests",
			attempts, numRequests)
	}
}
// MockExpiredSession is a minimal test double for a session whose tokens have
// expired. It only carries the values the refresh tests need.
type MockExpiredSession struct {
	refreshToken string // refresh token the session would present
	sessionID    string // identifier of the session
	isExpired    bool   // value reported by IsExpired
}

// GetRefreshToken returns the refresh token stored in the mock.
func (s *MockExpiredSession) GetRefreshToken() string { return s.refreshToken }

// GetSessionID returns the session identifier stored in the mock.
func (s *MockExpiredSession) GetSessionID() string { return s.sessionID }

// IsExpired reports the expiry flag stored in the mock.
func (s *MockExpiredSession) IsExpired() bool { return s.isExpired }
// BenchmarkRefreshWithCoordinator measures performance with the fix
func BenchmarkRefreshWithCoordinator(b *testing.B) {
logger := GetSingletonNoOpLogger()
config := DefaultRefreshCoordinatorConfig()
coordinator := NewRefreshCoordinator(config, logger)
defer coordinator.Shutdown()
refreshFunc := func() (*TokenResponse, error) {
// Simulate token refresh
time.Sleep(10 * time.Millisecond)
return &TokenResponse{
AccessToken: "new_token",
RefreshToken: "new_refresh",
}, nil
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
ctx := context.Background()
sessionID := fmt.Sprintf("session_%d", i%10)
refreshToken := "refresh_token"
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
i++
}
})
b.StopTimer()
metrics := coordinator.GetMetrics()
b.Logf("Total requests: %v", metrics["total_requests"])
b.Logf("Deduplicated: %v", metrics["deduplicated_requests"])
b.Logf("Success rate: %.2f%%",
float64(metrics["successful_refreshes"].(int64))/
float64(metrics["total_requests"].(int64))*100)
}
// TestRefreshCoordinatorIntegration tests the full integration.
//
// All subtests share ONE coordinator and run in declaration order, so state
// left behind by an earlier subtest (circuit breaker state, per-session
// attempt trackers) is visible to the later ones.
func TestRefreshCoordinatorIntegration(t *testing.T) {
// This test verifies the coordinator integrates properly with:
// 1. Circuit breaker
// 2. Rate limiting
// 3. Deduplication
// 4. Memory management
// 5. Cleanup routines
logger := GetSingletonNoOpLogger()
config := DefaultRefreshCoordinatorConfig()
config.MaxRefreshAttempts = 5
config.RefreshAttemptWindow = 1 * time.Second
config.RefreshCooldownPeriod = 2 * time.Second
config.MaxConcurrentRefreshes = 3
config.CleanupInterval = 500 * time.Millisecond
coordinator := NewRefreshCoordinator(config, logger)
defer coordinator.Shutdown()
// Test 1: Normal operation - a single successful refresh returns the
// token produced by the refresh function.
t.Run("NormalOperation", func(t *testing.T) {
refreshFunc := func() (*TokenResponse, error) {
return &TokenResponse{AccessToken: "token1"}, nil
}
ctx := context.Background()
result, err := coordinator.CoordinateRefresh(ctx, "session1", "refresh1", refreshFunc)
if err != nil {
t.Errorf("Normal refresh failed: %v", err)
}
if result == nil || result.AccessToken != "token1" {
t.Error("Invalid result from normal refresh")
}
})
// Test 2: Circuit breaker activation
t.Run("CircuitBreaker", func(t *testing.T) {
failingRefresh := func() (*TokenResponse, error) {
return nil, fmt.Errorf("service unavailable")
}
// Trigger circuit breaker. NewRefreshCoordinator sets MaxFailures to 3,
// so the breaker opens once three of these calls have failed.
for i := 0; i < 4; i++ {
ctx := context.Background()
_, _ = coordinator.CoordinateRefresh(ctx,
fmt.Sprintf("cb_session_%d", i), "refresh_cb", failingRefresh)
}
// Next request should be blocked by circuit breaker; the check matches
// on the "circuit breaker" substring of the returned error.
ctx := context.Background()
_, err := coordinator.CoordinateRefresh(ctx, "cb_session_blocked", "refresh_cb", failingRefresh)
if err == nil || !strings.Contains(err.Error(), "circuit breaker") {
t.Errorf("Circuit breaker should have blocked request: %v", err)
}
})
// Test 3: Rate limiting
t.Run("RateLimiting", func(t *testing.T) {
// Reset circuit breaker to closed state for this test (the previous
// subtest left it open).
coordinator.circuitBreaker.mutex.Lock()
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
coordinator.circuitBreaker.mutex.Unlock()
// Temporarily increase circuit breaker threshold to not interfere
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
coordinator.circuitBreaker.config.MaxFailures = 20
defer func() {
coordinator.circuitBreaker.config.MaxFailures = oldMaxFailures
}()
failingRefresh := func() (*TokenResponse, error) {
return nil, fmt.Errorf("failed")
}
sessionID := "rate_limit_session"
// Exhaust attempts
for i := 0; i < config.MaxRefreshAttempts+1; i++ {
ctx := context.Background()
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh)
// Add delay to ensure operations complete and aren't deduplicated
// (150ms exceeds the default 100ms DeduplicationCleanupDelay).
time.Sleep(150 * time.Millisecond)
}
// Should be in cooldown; the check matches on the "cooldown" substring
// of the returned error.
ctx := context.Background()
_, err := coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh)
if err == nil || !strings.Contains(err.Error(), "cooldown") {
t.Errorf("Rate limiting should have triggered cooldown: %v", err)
}
})
// Test 4: Cleanup
t.Run("Cleanup", func(t *testing.T) {
// Add some sessions
for i := 0; i < 5; i++ {
coordinator.recordRefreshAttempt(fmt.Sprintf("cleanup_session_%d", i))
}
// Wait for cleanup (three cleanup intervals = 1.5s with this config)
time.Sleep(config.CleanupInterval * 3)
// Old sessions should be cleaned up
// NOTE(review): cleanupStaleEntries only removes trackers idle for more
// than 2*RefreshAttemptWindow (2s here), which is longer than the 1.5s
// wait above, so the five trackers just added are probably still present.
// The "count > 10" bound below passes either way - confirm whether a
// stricter assertion was intended.
coordinator.attemptsMutex.RLock()
count := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
// Should have fewer sessions after cleanup
if count > 10 {
t.Errorf("Cleanup not working, %d sessions remain", count)
}
})
// Verify final metrics (logged only, no assertions)
metrics := coordinator.GetMetrics()
t.Logf("Final metrics: %+v", metrics)
}
+6
View File
@@ -32,6 +32,12 @@ func GetMemoryOptimizations() *MemoryOptimizations {
return globalMemoryOpts
}
// ResetGlobalMemoryOptimizations resets the global memory optimizations for testing.
// It replaces the package-level sync.Once with a fresh one and clears the
// cached instance pointer.
func ResetGlobalMemoryOptimizations() {
	globalMemoryOptsOnce, globalMemoryOpts = sync.Once{}, nil
}
// BufferPool manages a pool of byte buffers
type BufferPool struct {
pool sync.Pool
+3 -1
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
)
@@ -95,7 +96,8 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
}
// Fetch from provider
metadataURL := providerURL + "/.well-known/openid-configuration"
// Ensure no double slashes by trimming trailing slash from provider URL
metadataURL := strings.TrimRight(providerURL, "/") + "/.well-known/openid-configuration"
mc.logger.Infof("Fetching provider metadata from: %s", metadataURL)
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
+596
View File
@@ -0,0 +1,596 @@
package traefikoidc
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"sync"
"sync/atomic"
"time"
)
// RefreshCoordinator prevents duplicate refresh token operations and manages
// refresh attempt tracking to prevent infinite loops and OOM conditions.
// It implements request coalescing, rate limiting, and circuit breaking
// specifically for token refresh operations.
//
// Instances are created with NewRefreshCoordinator, which also starts the
// background cleanup goroutine; Shutdown stops that goroutine.
type RefreshCoordinator struct {
// inFlightRefreshes tracks active refresh operations, keyed by the
// hex-encoded SHA-256 hash of the refresh token (see hashRefreshToken)
inFlightRefreshes map[string]*refreshOperation
// refreshMutex protects the inFlightRefreshes map
refreshMutex sync.RWMutex
// sessionRefreshAttempts tracks refresh attempts per session, keyed by session ID
sessionRefreshAttempts map[string]*refreshAttemptTracker
// attemptsMutex protects sessionRefreshAttempts map
attemptsMutex sync.RWMutex
// circuitBreaker is consulted at the start of every CoordinateRefresh call
circuitBreaker *RefreshCircuitBreaker
// config holds the settings passed to NewRefreshCoordinator
config RefreshCoordinatorConfig
// metrics holds counters that are updated with sync/atomic operations
metrics *RefreshMetrics
// logger is never nil; NewRefreshCoordinator substitutes a no-op logger
logger *Logger
// stopChan is closed by Shutdown to stop cleanupRoutine;
// wg lets Shutdown wait for cleanupRoutine to exit
stopChan chan struct{}
wg sync.WaitGroup
}
// RefreshCoordinatorConfig configures the refresh coordinator behavior.
// DefaultRefreshCoordinatorConfig returns the standard values.
type RefreshCoordinatorConfig struct {
// Maximum refresh attempts per session before giving up.
// isInCooldown starts a cooldown once the recorded attempts in the current
// window reach this value.
MaxRefreshAttempts int
// Time window for refresh attempt tracking; the attempt counter restarts
// once the window has elapsed
RefreshAttemptWindow time.Duration
// Cooldown period after max attempts reached
RefreshCooldownPeriod time.Duration
// Maximum concurrent refresh operations; new operations are rejected while
// this many are registered as in flight
MaxConcurrentRefreshes int
// Timeout for individual refresh operations
RefreshTimeout time.Duration
// Enable memory pressure detection.
// NOTE: isUnderMemoryPressure is currently a placeholder that always
// returns false, so enabling this never rejects a refresh.
EnableMemoryPressureDetection bool
// Memory pressure threshold (in MB).
// NOTE(review): not read by any code visible in this file.
MemoryPressureThresholdMB uint64
// Cleanup interval for stale entries (period of the cleanup ticker)
CleanupInterval time.Duration
// Delay before cleaning up completed refresh operations from deduplication map
// Set to 0 for immediate cleanup (useful for tests)
DeduplicationCleanupDelay time.Duration
}
// DefaultRefreshCoordinatorConfig returns the configuration intended for
// production use. Callers may adjust individual fields on the returned value
// before passing it to NewRefreshCoordinator.
func DefaultRefreshCoordinatorConfig() RefreshCoordinatorConfig {
	var cfg RefreshCoordinatorConfig

	// Per-session rate limiting.
	cfg.MaxRefreshAttempts = 5
	cfg.RefreshAttemptWindow = 5 * time.Minute
	cfg.RefreshCooldownPeriod = 10 * time.Minute

	// Global limits on refresh operations.
	cfg.MaxConcurrentRefreshes = 10
	cfg.RefreshTimeout = 30 * time.Second

	// Memory pressure detection (500MB threshold).
	cfg.EnableMemoryPressureDetection = true
	cfg.MemoryPressureThresholdMB = 500

	// Housekeeping.
	cfg.CleanupInterval = 1 * time.Minute
	cfg.DeduplicationCleanupDelay = 100 * time.Millisecond

	return cfg
}
// refreshOperation represents an in-flight refresh operation that one caller
// executes and any number of callers wait on.
type refreshOperation struct {
// refreshToken being refreshed (for validation when a caller joins)
refreshToken string
// result stores the final result; it is set before done is closed
result *refreshResult
// done signals when the operation is complete (closed, never sent on)
done chan struct{}
// startTime tracks when the operation started
startTime time.Time
// waiterCount tracks number of goroutines waiting; incremented atomically
// when a caller joins. NOTE(review): never read or decremented in the
// code visible here.
waiterCount int32
// mutex protects the result field
mutex sync.RWMutex
}
// refreshResult contains the result of a refresh operation
type refreshResult struct {
// tokenResponse is the value returned by the refresh function (nil on timeout)
tokenResponse *TokenResponse
// err is the refresh function's error, or a timeout error
err error
// fromCache is always set to false by the code visible in this file
fromCache bool
}
// refreshAttemptTracker tracks refresh attempts for a session.
// Trackers live in RefreshCoordinator.sessionRefreshAttempts and are accessed
// while attemptsMutex is held.
type refreshAttemptTracker struct {
// attempts counts refresh attempts in current window
attempts int32
// lastAttemptTime is the timestamp of the last attempt; cleanupStaleEntries
// uses it to decide when a tracker is stale
lastAttemptTime time.Time
// windowStartTime is when the current tracking window started
windowStartTime time.Time
// inCooldown indicates if this session is in cooldown
inCooldown bool
// cooldownEndTime is when cooldown period ends
cooldownEndTime time.Time
// consecutiveFailures tracks consecutive refresh failures.
// NOTE(review): written but not read by the code visible here.
consecutiveFailures int32
}
// RefreshMetrics tracks coordinator performance metrics.
// All fields are updated and read with sync/atomic operations; GetMetrics
// exposes them as a map.
type RefreshMetrics struct {
// totalRefreshRequests counts every CoordinateRefresh call
totalRefreshRequests int64
// deduplicatedRequests counts calls that joined an existing operation
deduplicatedRequests int64
// successfulRefreshes counts calls that returned a result without error
successfulRefreshes int64
// failedRefreshes counts calls that returned a result with an error
failedRefreshes int64
// circuitBreakerTrips counts calls rejected by the circuit breaker
circuitBreakerTrips int64
// memoryPressureEvents counts calls rejected due to memory pressure
memoryPressureEvents int64
// cooldownsTriggered counts calls rejected because the session is in cooldown
cooldownsTriggered int64
// currentInFlightRefreshes is the number of operations currently registered;
// it is decremented when a finished operation is removed from the map
currentInFlightRefreshes int32
}
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
type RefreshCircuitBreaker struct {
// state is accessed with sync/atomic
state int32 // 0=closed, 1=open, 2=half-open
// failures counts recorded failures; reset to zero by RecordSuccess while
// the breaker is closed or half-open
failures int32
// lastFailureTime is set by RecordFailure; AllowRequest compares it with
// config.OpenDuration to decide when an open breaker may go half-open
lastFailureTime time.Time
// lastSuccessTime is set by RecordSuccess
lastSuccessTime time.Time
// config holds the thresholds for this breaker
config RefreshCircuitBreakerConfig
// mutex is held by AllowRequest (read lock) and RecordSuccess/RecordFailure (write lock)
mutex sync.RWMutex
}
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
type RefreshCircuitBreakerConfig struct {
// MaxFailures is the failure count at which a closed breaker opens
MaxFailures int
// OpenDuration is how long after the last failure an open breaker keeps
// rejecting requests before it may move to half-open
OpenDuration time.Duration
// HalfOpenRequests is set by NewRefreshCoordinator.
// NOTE(review): not consulted by AllowRequest in the code visible here.
HalfOpenRequests int
}
// NewRefreshCoordinator creates a new refresh coordinator and starts its
// background cleanup goroutine. Call Shutdown to stop that goroutine.
//
// A nil logger is replaced by the singleton no-op logger. Two durations must
// be positive for the coordinator to work at all and are replaced by their
// defaults from DefaultRefreshCoordinatorConfig when they are not:
//   - CleanupInterval: time.NewTicker panics on a non-positive interval,
//     which would crash the cleanup goroutine right after construction.
//   - RefreshTimeout: a non-positive timeout yields an already-expired
//     context, so every refresh would report a timeout immediately.
//
// All other fields are used exactly as given.
func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *RefreshCoordinator {
	if logger == nil {
		logger = GetSingletonNoOpLogger()
	}

	// Guard against partially filled (for example zero-value) configurations.
	defaults := DefaultRefreshCoordinatorConfig()
	if config.CleanupInterval <= 0 {
		config.CleanupInterval = defaults.CleanupInterval
	}
	if config.RefreshTimeout <= 0 {
		config.RefreshTimeout = defaults.RefreshTimeout
	}

	rc := &RefreshCoordinator{
		inFlightRefreshes:      make(map[string]*refreshOperation),
		sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
		config:                 config,
		metrics:                &RefreshMetrics{},
		logger:                 logger,
		stopChan:               make(chan struct{}),
		circuitBreaker: &RefreshCircuitBreaker{
			config: RefreshCircuitBreakerConfig{
				MaxFailures:      3,
				OpenDuration:     30 * time.Second,
				HalfOpenRequests: 1,
			},
		},
	}

	// Start cleanup goroutine; wg lets Shutdown wait for it to exit.
	rc.wg.Add(1)
	go rc.cleanupRoutine()

	return rc
}
// CoordinateRefresh ensures only one refresh operation happens per refresh token
// and implements request coalescing for concurrent refresh attempts.
//
// The first caller for a given refresh token creates the operation and starts
// refreshFunc in a background goroutine; callers arriving while that operation
// is still registered join it and receive the same result.
//
// Parameters:
//   - ctx: bounds how long this caller waits; it does not cancel the refresh.
//   - sessionID: key for per-session attempt tracking and cooldown.
//   - refreshToken: token being refreshed; its hash is the deduplication key.
//   - refreshFunc: performs the actual refresh; must not be nil.
//
// Returns the shared token response and error of the operation, or an error
// when refreshFunc is nil, the circuit breaker is open, the operation was
// rejected (cooldown, memory pressure, concurrency limit), or ctx ended first.
//
// The per-request success/failure counters are updated for every caller. The
// circuit breaker and the per-session failure tracking are updated only by the
// caller that created the operation, so one upstream failure shared by many
// waiters counts as a single failure instead of tripping the breaker at once.
// If that caller's ctx ends before the operation completes, the outcome is
// not recorded on the breaker.
func (rc *RefreshCoordinator) CoordinateRefresh(
	ctx context.Context,
	sessionID string,
	refreshToken string,
	refreshFunc func() (*TokenResponse, error),
) (*TokenResponse, error) {
	// Increment total request count.
	atomic.AddInt64(&rc.metrics.totalRefreshRequests, 1)

	// A nil callback would otherwise panic inside the background goroutine
	// and take the whole process down.
	if refreshFunc == nil {
		return nil, fmt.Errorf("refresh function must not be nil")
	}

	// Check circuit breaker first.
	if !rc.circuitBreaker.AllowRequest() {
		atomic.AddInt64(&rc.metrics.circuitBreakerTrips, 1)
		return nil, fmt.Errorf("refresh circuit breaker is open due to repeated failures")
	}

	// Create hash of refresh token for deduplication.
	tokenHash := rc.hashRefreshToken(refreshToken)

	// Atomically join an existing operation or register a new one, so that
	// concurrent callers cannot all miss the lookup and each start a refresh.
	operation, isNew, err := rc.getOrCreateOperation(ctx, sessionID, tokenHash, refreshToken)
	if err != nil {
		// Operation creation was rejected (rate limit, memory pressure, concurrent limit).
		return nil, err
	}

	if isNew {
		// We created the operation, so we are responsible for running it.
		go rc.executeRefreshAsync(operation, sessionID, tokenHash, refreshFunc)
	} else {
		// Joined existing operation - this is a deduplicated request.
		atomic.AddInt64(&rc.metrics.deduplicatedRequests, 1)
	}

	// Wait for the operation to complete or for this caller to give up.
	select {
	case <-operation.done:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	operation.mutex.RLock()
	result := operation.result
	operation.mutex.RUnlock()
	if result == nil {
		return nil, fmt.Errorf("refresh operation completed without result")
	}

	failed := result.err != nil

	// Per-request counters: every caller counts.
	if failed {
		atomic.AddInt64(&rc.metrics.failedRefreshes, 1)
	} else {
		atomic.AddInt64(&rc.metrics.successfulRefreshes, 1)
	}

	// Per-operation bookkeeping: only the creator reports to the breaker.
	if isNew {
		if failed {
			rc.circuitBreaker.RecordFailure()
			rc.recordRefreshFailure(sessionID)
		} else {
			rc.circuitBreaker.RecordSuccess()
			rc.recordRefreshSuccess(sessionID)
		}
	}

	return result.tokenResponse, result.err
}
// getOrCreateOperation looks up the operation registered for tokenHash and
// joins it, or registers a new one, all while holding refreshMutex so that
// two callers can never both create an operation for the same token.
//
// Return values:
//   - (operation, true, nil): a new operation was created; the caller must run it.
//   - (operation, false, nil): the caller joined an existing operation.
//   - (nil, false, err): the request was rejected.
//
// For a new operation the checks run in this order: the attempt is recorded
// for the session, then cooldown, memory pressure and the concurrency limit
// are checked. The attempt stays recorded even if a later check rejects.
func (rc *RefreshCoordinator) getOrCreateOperation(
	ctx context.Context,
	sessionID string,
	tokenHash string,
	refreshToken string,
) (*refreshOperation, bool, error) {
	rc.refreshMutex.Lock()
	defer rc.refreshMutex.Unlock()

	// Fast path: somebody is already refreshing this token.
	if inFlight, found := rc.inFlightRefreshes[tokenHash]; found {
		if inFlight.refreshToken != refreshToken {
			// Different refresh token for same hash - should not happen.
			return nil, false, fmt.Errorf("refresh token mismatch")
		}
		atomic.AddInt32(&inFlight.waiterCount, 1)
		return inFlight, false, nil
	}

	// Slow path: decide whether a new operation may start. Everything below
	// runs under the lock, so the checks and the registration are one step.

	// Rate limiting: count this attempt, then see whether the session is
	// (or has just entered) cooldown.
	rc.recordRefreshAttempt(sessionID)
	if rc.isInCooldown(sessionID) {
		atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
		return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
	}

	// Memory pressure.
	if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
		atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
		return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
	}

	// Concurrency limit: check, then reserve a slot. Holding the lock keeps
	// another creator from reserving between the two steps.
	inUse := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
	if int(inUse) >= rc.config.MaxConcurrentRefreshes {
		return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
	}
	atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)

	// Register the new operation; the creator counts as its first waiter.
	created := &refreshOperation{
		refreshToken: refreshToken,
		done:         make(chan struct{}),
		startTime:    time.Now(),
		waiterCount:  1,
	}
	rc.inFlightRefreshes[tokenHash] = created
	return created, true, nil
}
// executeRefreshAsync performs the actual refresh operation asynchronously.
//
// It runs refreshFunc in its own goroutine, waits for either its result or
// RefreshTimeout, stores the outcome on operation and then closes
// operation.done to wake all waiters. Afterwards the operation is removed
// from the in-flight map (after DeduplicationCleanupDelay, if positive) and
// the in-flight slot is released.
//
// A panic inside refreshFunc is recovered and reported to the waiters as an
// error: the callback runs in a bare goroutine, where an unrecovered panic
// would terminate the whole process.
//
// On timeout refreshFunc is not interrupted (it receives no context); its
// goroutine finishes on its own and its late result is discarded.
func (rc *RefreshCoordinator) executeRefreshAsync(
	operation *refreshOperation,
	sessionID string,
	tokenHash string,
	refreshFunc func() (*TokenResponse, error),
) {
	defer func() {
		// Signal completion to all waiters.
		close(operation.done)

		// Keep the finished operation registered for a short while so that
		// requests arriving just after completion still reuse its result,
		// then unregister it and release the concurrency slot.
		go func() {
			if rc.config.DeduplicationCleanupDelay > 0 {
				time.Sleep(rc.config.DeduplicationCleanupDelay)
			}
			rc.refreshMutex.Lock()
			delete(rc.inFlightRefreshes, tokenHash)
			rc.refreshMutex.Unlock()
			atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
		}()
	}()

	// publish stores the outcome where waiters will read it.
	publish := func(resp *TokenResponse, err error) {
		operation.mutex.Lock()
		operation.result = &refreshResult{
			tokenResponse: resp,
			err:           err,
			fromCache:     false,
		}
		operation.mutex.Unlock()
	}

	refreshCtx, cancel := context.WithTimeout(context.Background(), rc.config.RefreshTimeout)
	defer cancel()

	type refreshOutcome struct {
		resp *TokenResponse
		err  error
	}
	// Capacity 1 and exactly one send: the worker never blocks, even when
	// nobody is receiving any more because the timeout fired first.
	resultChan := make(chan refreshOutcome, 1)

	go func() {
		var out refreshOutcome
		defer func() {
			if r := recover(); r != nil {
				out = refreshOutcome{err: fmt.Errorf("refresh function panicked: %v", r)}
			}
			resultChan <- out
		}()
		out.resp, out.err = refreshFunc()
	}()

	select {
	case out := <-resultChan:
		publish(out.resp, out.err)
	case <-refreshCtx.Done():
		publish(nil, fmt.Errorf("refresh operation timed out after %v", rc.config.RefreshTimeout))
	}
}
// isInCooldown reports whether sessionID must be refused a refresh right now.
// It is meant to run after recordRefreshAttempt has counted the current
// attempt, which is why counters are reset to 1 rather than 0.
//
// A session without a tracker is never in cooldown. Otherwise the first
// matching rule wins:
//  1. already in cooldown: still refused until cooldownEndTime has passed,
//     after which the tracker is reset;
//  2. attempt window elapsed: the window restarts;
//  3. attempts have reached MaxRefreshAttempts: cooldown starts now.
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
	rc.attemptsMutex.Lock()
	defer rc.attemptsMutex.Unlock()

	tracker, tracked := rc.sessionRefreshAttempts[sessionID]
	if !tracked {
		return false
	}

	now := time.Now()
	switch {
	case tracker.inCooldown:
		if !now.After(tracker.cooldownEndTime) {
			return true // still cooling down
		}
		// Cooldown is over: start again with the current attempt counted.
		tracker.inCooldown = false
		tracker.attempts = 1
		tracker.consecutiveFailures = 0
		tracker.windowStartTime = now
		return false

	case now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow:
		// Window elapsed: open a new one with the current attempt counted.
		tracker.attempts = 1
		tracker.windowStartTime = now
		return false

	case int(tracker.attempts) >= rc.config.MaxRefreshAttempts:
		// Limit reached within the window: enter cooldown.
		tracker.inCooldown = true
		tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
		rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
			sessionID, tracker.attempts)
		return true
	}

	return false
}
// recordRefreshAttempt counts one refresh attempt for sessionID, creating the
// session's tracker (with a window starting now) on first use, and stamps the
// time of the attempt.
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
	rc.attemptsMutex.Lock()
	defer rc.attemptsMutex.Unlock()

	if _, known := rc.sessionRefreshAttempts[sessionID]; !known {
		rc.sessionRefreshAttempts[sessionID] = &refreshAttemptTracker{
			windowStartTime: time.Now(),
		}
	}

	entry := rc.sessionRefreshAttempts[sessionID]
	atomic.AddInt32(&entry.attempts, 1)
	entry.lastAttemptTime = time.Now()
}
// recordRefreshSuccess clears the consecutive-failure count of sessionID.
// Sessions without a tracker are left untouched.
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
	rc.attemptsMutex.Lock()
	defer rc.attemptsMutex.Unlock()

	entry, known := rc.sessionRefreshAttempts[sessionID]
	if !known {
		return
	}
	entry.consecutiveFailures = 0
}
// recordRefreshFailure adds one to the consecutive-failure count of sessionID.
// Sessions without a tracker are left untouched.
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
	rc.attemptsMutex.Lock()
	defer rc.attemptsMutex.Unlock()

	entry, known := rc.sessionRefreshAttempts[sessionID]
	if !known {
		return
	}
	atomic.AddInt32(&entry.consecutiveFailures, 1)
}
// hashRefreshToken returns the lowercase hex encoding of the SHA-256 digest
// of token. The digest, not the token itself, keys the deduplication map.
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
	hasher := sha256.New()
	hasher.Write([]byte(token)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// isUnderMemoryPressure reports whether refreshes should be denied because
// memory is running low.
//
// Placeholder: no measurement is implemented yet, so it always reports false.
// A real check would use runtime.MemStats or system-specific monitoring.
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
	const underPressure = false
	return underPressure
}
// cleanupRoutine is the body of the background goroutine started by
// NewRefreshCoordinator. Every CleanupInterval it removes stale tracking
// entries; it exits when stopChan is closed and then marks wg as done.
func (rc *RefreshCoordinator) cleanupRoutine() {
	defer rc.wg.Done()

	tick := time.NewTicker(rc.config.CleanupInterval)
	defer tick.Stop()

	for {
		select {
		case <-rc.stopChan:
			return
		case <-tick.C:
			rc.cleanupStaleEntries()
		}
	}
}
// cleanupStaleEntries drops every session tracker whose last attempt is more
// than twice RefreshAttemptWindow in the past.
func (rc *RefreshCoordinator) cleanupStaleEntries() {
	now := time.Now()
	staleAfter := 2 * rc.config.RefreshAttemptWindow

	rc.attemptsMutex.Lock()
	defer rc.attemptsMutex.Unlock()

	// Deleting the current key while ranging over a map is safe in Go.
	for id, entry := range rc.sessionRefreshAttempts {
		idle := now.Sub(entry.lastAttemptTime)
		if idle > staleAfter {
			delete(rc.sessionRefreshAttempts, id)
		}
	}
}
// GetMetrics returns a snapshot of the coordinator's counters as a map.
// Counter values are int64, except "current_inflight" (int32) and
// "circuit_breaker_state" (string). Each counter is loaded atomically on its
// own, so the snapshot is not a single consistent point in time.
func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
	m := rc.metrics

	snapshot := make(map[string]interface{}, 9)
	snapshot["total_requests"] = atomic.LoadInt64(&m.totalRefreshRequests)
	snapshot["deduplicated_requests"] = atomic.LoadInt64(&m.deduplicatedRequests)
	snapshot["successful_refreshes"] = atomic.LoadInt64(&m.successfulRefreshes)
	snapshot["failed_refreshes"] = atomic.LoadInt64(&m.failedRefreshes)
	snapshot["circuit_breaker_trips"] = atomic.LoadInt64(&m.circuitBreakerTrips)
	snapshot["memory_pressure_events"] = atomic.LoadInt64(&m.memoryPressureEvents)
	snapshot["cooldowns_triggered"] = atomic.LoadInt64(&m.cooldownsTriggered)
	snapshot["current_inflight"] = atomic.LoadInt32(&m.currentInFlightRefreshes)
	snapshot["circuit_breaker_state"] = rc.circuitBreaker.GetState()
	return snapshot
}
// Shutdown gracefully shuts down the coordinator: it signals the cleanup
// goroutine to stop and waits until it has exited.
//
// Shutdown is safe to call more than once, also concurrently; only the first
// call closes stopChan. Closing an already closed channel would panic, so the
// check-and-close is serialized with refreshMutex (the cleanup goroutine
// never takes that mutex, so waiting for it below cannot deadlock).
func (rc *RefreshCoordinator) Shutdown() {
	rc.refreshMutex.Lock()
	select {
	case <-rc.stopChan:
		// Already closed by an earlier call.
	default:
		close(rc.stopChan)
	}
	rc.refreshMutex.Unlock()

	rc.wg.Wait()
}
// AllowRequest reports whether a refresh may proceed.
//
//   - closed (0) and half-open (2): always allowed.
//   - open (1): allowed only once more than OpenDuration has passed since the
//     last recorded failure, and only for the caller that wins the atomic
//     switch to half-open.
//   - any other state value: refused.
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
	cb.mutex.RLock()
	defer cb.mutex.RUnlock()

	current := atomic.LoadInt32(&cb.state)
	if current == 0 || current == 2 {
		return true
	}
	if current != 1 {
		return false
	}

	// Open: probe again only after the open period has elapsed. The
	// compare-and-swap lets exactly one of several racing callers through.
	openPeriodOver := time.Since(cb.lastFailureTime) > cb.config.OpenDuration
	return openPeriodOver && atomic.CompareAndSwapInt32(&cb.state, 1, 2)
}
// RecordSuccess records a successful operation. A half-open breaker closes
// and its failure count is cleared; a closed breaker just has its failure
// count cleared; an open breaker is left as it is. The time of the success is
// stored in every case.
func (cb *RefreshCircuitBreaker) RecordSuccess() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	switch atomic.LoadInt32(&cb.state) {
	case 2: // half-open: the probe succeeded, close the circuit
		atomic.StoreInt32(&cb.state, 0)
		atomic.StoreInt32(&cb.failures, 0)
	case 0: // closed: a success resets the failure count
		atomic.StoreInt32(&cb.failures, 0)
	}

	cb.lastSuccessTime = time.Now()
}
// RecordFailure records a failed operation.
//
// The failure counter is incremented and lastFailureTime is refreshed on
// every call, regardless of state. The circuit then moves to open (1) when
// either the closed circuit (0) has reached config.MaxFailures failures, or
// the circuit was half-open (2) and its probe failed. The whole update runs
// under the write lock.
func (cb *RefreshCircuitBreaker) RecordFailure() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	count := atomic.AddInt32(&cb.failures, 1)
	cb.lastFailureTime = time.Now()

	switch atomic.LoadInt32(&cb.state) {
	case 0: // Closed: trip once the failure threshold is reached.
		if int(count) >= cb.config.MaxFailures {
			atomic.StoreInt32(&cb.state, 1)
		}
	case 2: // Half-open: the probe failed, fall back to open.
		atomic.StoreInt32(&cb.state, 1)
	}
}
// GetState returns the current state of the circuit breaker as a string:
// "closed" for 0, "open" for 1, "half-open" for 2, and "unknown" for any
// other stored value. The state is read with a single atomic load; no lock is
// taken.
func (cb *RefreshCircuitBreaker) GetState() string {
	// Index in this table == numeric state value.
	names := [...]string{"closed", "open", "half-open"}

	code := atomic.LoadInt32(&cb.state)
	if code < 0 || int(code) >= len(names) {
		return "unknown"
	}
	return names[code]
}
+669
View File
@@ -0,0 +1,669 @@
package traefikoidc
import (
"context"
"fmt"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestConcurrentRefreshDeduplication verifies that concurrent refresh attempts
// for the same token are deduplicated and only one refresh operation occurs.
//
// 100 goroutines call CoordinateRefresh with the same session ID and refresh
// token while the mock refresh takes 100ms. The test then checks that the
// refresh function ran at most twice, that every caller received a response
// carrying the same access token, that nobody got an error, and that the
// "deduplicated_requests" metric accounts for at least 98 callers.
func TestConcurrentRefreshDeduplication(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	// Keep default delay for this test - it's testing deduplication behavior.
	// Disable rate limiting for this test since we're testing deduplication.
	config.MaxRefreshAttempts = 1000 // High enough to not interfere

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Counter to track actual refresh executions.
	var refreshExecutions int32

	// Mock refresh function; the sleep keeps the operation in flight long
	// enough for the other callers to arrive.
	refreshFunc := func() (*TokenResponse, error) {
		atomic.AddInt32(&refreshExecutions, 1)
		time.Sleep(100 * time.Millisecond)
		return &TokenResponse{
			AccessToken:  "new_access_token",
			RefreshToken: "new_refresh_token",
			IDToken:      "new_id_token",
			ExpiresIn:    3600,
		}, nil
	}

	// Number of concurrent requests.
	numRequests := 100
	var wg sync.WaitGroup
	wg.Add(numRequests)

	// Both channels are buffered for every caller so a send never blocks.
	results := make(chan *TokenResponse, numRequests)
	errs := make(chan error, numRequests)

	// Unique identifiers so repeated runs cannot collide.
	refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())
	sessionID := fmt.Sprintf("test_session_%d", time.Now().UnixNano())

	for i := 0; i < numRequests; i++ {
		go func() {
			defer wg.Done()

			resp, err := coordinator.CoordinateRefresh(
				context.Background(),
				sessionID,
				refreshToken,
				refreshFunc,
			)
			if err != nil {
				errs <- err
				return
			}
			results <- resp
		}()
	}

	// Wait for all goroutines to complete before draining the channels.
	wg.Wait()
	close(results)
	close(errs)

	// Allow for slight timing variations - up to 2 executions is acceptable.
	// This can happen when a second goroutine starts just as the first completes.
	actualExecutions := atomic.LoadInt32(&refreshExecutions)
	if actualExecutions > 2 {
		t.Errorf("Expected 1-2 refresh executions, got %d", actualExecutions)
	}

	// Verify all requests got an equivalent result. The comparison is on the
	// access token value, not on pointer identity.
	var firstResponse *TokenResponse
	responseCount := 0
	for resp := range results {
		responseCount++
		if resp == nil {
			// A nil response without an error would otherwise panic below.
			t.Error("Nil response returned without an error")
			continue
		}
		if firstResponse == nil {
			firstResponse = resp
			continue
		}
		if resp.AccessToken != firstResponse.AccessToken {
			t.Error("Different responses returned for concurrent requests")
		}
	}

	// Check for errors.
	errorCount := 0
	for range errs {
		errorCount++
	}
	if errorCount > 0 {
		t.Errorf("Unexpected errors in concurrent requests: %d", errorCount)
	}

	if responseCount != numRequests {
		t.Errorf("Expected %d successful responses, got %d", numRequests, responseCount)
	}

	// Verify metrics. A missing or differently typed metric is a failure, not
	// a reason to skip the assertion.
	deduped, ok := coordinator.GetMetrics()["deduplicated_requests"].(int64)
	if !ok {
		t.Fatal("deduplicated_requests metric is missing or is not an int64")
	}
	// Allow for slight timing variations - at least 98 out of 100 should be deduplicated.
	if deduped < int64(numRequests-2) {
		t.Errorf("Expected at least %d deduplicated requests, got %d", numRequests-2, deduped)
	}
}
// TestRefreshRateLimiting verifies that refresh attempts are rate-limited per
// session.
//
// With MaxRefreshAttempts = 3 the test expects two attempts to go through and
// the third to be rejected with the cooldown error, then checks that the
// session stays blocked during the 2s cooldown and is admitted again once the
// cooldown has elapsed.
func TestRefreshRateLimiting(t *testing.T) {
	// Error text the test recognises as "session is in cooldown".
	const cooldownMsg = "refresh attempts exceeded for session, in cooldown period"
	isCooldown := func(err error) bool {
		return err != nil && err.Error() == cooldownMsg
	}

	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.MaxRefreshAttempts = 3
	config.RefreshAttemptWindow = 1 * time.Second
	config.RefreshCooldownPeriod = 2 * time.Second

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Raise the circuit breaker threshold so it cannot trip first: this test
	// is about rate limiting, not about the breaker.
	coordinator.circuitBreaker.config.MaxFailures = 10

	sessionID := "rate_limited_session"
	refreshToken := "test_refresh_token"

	// Every refresh fails, so each call counts as an attempt.
	alwaysFail := func() (*TokenResponse, error) {
		return nil, fmt.Errorf("refresh failed")
	}

	// Keep attempting until the coordinator answers with the cooldown error
	// (or five tries have been made).
	var attempts int
	var cooldownTriggered bool
	for try := 0; try < 5; try++ {
		_, err := coordinator.CoordinateRefresh(context.Background(), sessionID, refreshToken, alwaysFail)
		if isCooldown(err) {
			cooldownTriggered = true
			break
		}
		attempts++

		// Give the operation time to finish so the next call is not deduplicated.
		time.Sleep(150 * time.Millisecond)
	}

	// The Nth attempt is the one that triggers the cooldown, so N-1 attempts
	// are expected to have gone through before it.
	expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
	if attempts != expectedSuccessfulAttempts {
		t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
	}
	if !cooldownTriggered {
		t.Error("Cooldown was not triggered after max attempts")
	}

	// While the cooldown is active the session must stay blocked.
	ctx := context.Background()
	_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, alwaysFail)
	if !isCooldown(err) {
		t.Error("Request should be blocked during cooldown period")
	}

	// Let the cooldown run out (plus a small margin).
	time.Sleep(config.RefreshCooldownPeriod + 100*time.Millisecond)

	// Any error other than the cooldown error is acceptable here: the refresh
	// function itself still fails.
	_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, alwaysFail)
	if isCooldown(err) {
		t.Error("Request should be allowed after cooldown period")
	}
}
// TestCircuitBreakerProtection verifies that the circuit breaker prevents
// cascading failures during repeated refresh failures.
//
// The breaker is configured to open after 3 failures and stay open for 1s.
// Five failing refreshes on distinct sessions must produce at least one
// "circuit breaker is open" rejection and leave the breaker open. After the
// open duration a successful refresh must be admitted and must close the
// breaker again.
func TestCircuitBreakerProtection(t *testing.T) {
	// Error text the test recognises as a breaker rejection.
	const breakerOpenMsg = "refresh circuit breaker is open due to repeated failures"

	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Trip after 3 failures, re-probe after one second.
	coordinator.circuitBreaker.config.MaxFailures = 3
	coordinator.circuitBreaker.config.OpenDuration = 1 * time.Second

	alwaysFail := func() (*TokenResponse, error) {
		return nil, fmt.Errorf("service unavailable")
	}

	// Drive failures through different sessions so per-session rate limiting
	// is not what rejects the calls.
	var tripCount int
	for n := 0; n < 5; n++ {
		_, err := coordinator.CoordinateRefresh(
			context.Background(),
			fmt.Sprintf("session_%d", n),
			"refresh_token",
			alwaysFail,
		)
		if err == nil {
			continue
		}
		if err.Error() == breakerOpenMsg {
			tripCount++
		}
	}

	if tripCount == 0 {
		t.Error("Circuit breaker did not trip after repeated failures")
	}

	if state := coordinator.circuitBreaker.GetState(); state != "open" {
		t.Errorf("Expected circuit breaker state 'open', got '%s'", state)
	}

	// Wait past the open duration so the next request may probe the circuit.
	time.Sleep(coordinator.circuitBreaker.config.OpenDuration + 100*time.Millisecond)

	alwaysSucceed := func() (*TokenResponse, error) {
		return &TokenResponse{
			AccessToken: "new_token",
		}, nil
	}

	// The probe request must be admitted...
	_, err := coordinator.CoordinateRefresh(context.Background(), "session_recovery", "refresh_token", alwaysSucceed)
	if err != nil {
		t.Errorf("Circuit breaker should allow request in half-open state: %v", err)
	}

	// ...and its success must close the circuit.
	if state := coordinator.circuitBreaker.GetState(); state != "closed" {
		t.Errorf("Expected circuit breaker state 'closed' after successful request, got '%s'",
			state)
	}
}
// TestMemoryLeakPrevention verifies that the coordinator doesn't leak memory
// during sustained concurrent refresh operations.
//
// Ten workers issue refreshes for five seconds, each with its own session and
// a new refresh token per call. Afterwards the test checks that heap growth
// stayed below 50MB, that no refresh is reported as in flight, and that the
// number of tracked sessions is bounded. Skipped in -short mode.
func TestMemoryLeakPrevention(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping memory leak test in short mode")
	}

	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.CleanupInterval = 100 * time.Millisecond
	config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Force garbage collection and record initial memory.
	runtime.GC()
	runtime.GC()
	var initialMem runtime.MemStats
	runtime.ReadMemStats(&initialMem)

	// ctx only bounds how long the workers keep looping; each refresh call
	// itself runs with a background context.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	numWorkers := 10
	wg.Add(numWorkers)

	// Each worker continuously attempts refreshes.
	for i := 0; i < numWorkers; i++ {
		go func(workerID int) {
			defer wg.Done()

			refreshCount := 0
			refreshFunc := func() (*TokenResponse, error) {
				// Simulate varying response times.
				time.Sleep(time.Duration(workerID*10) * time.Millisecond)
				return &TokenResponse{
					AccessToken:  fmt.Sprintf("token_%d_%d", workerID, refreshCount),
					RefreshToken: fmt.Sprintf("refresh_%d_%d", workerID, refreshCount),
				}, nil
			}

			for {
				select {
				case <-ctx.Done():
					return
				default:
					sessionID := fmt.Sprintf("session_%d", workerID)
					refreshToken := fmt.Sprintf("refresh_%d_%d", workerID, refreshCount)

					// Outcome is irrelevant here; only resource usage is measured.
					_, _ = coordinator.CoordinateRefresh(
						context.Background(),
						sessionID,
						refreshToken,
						refreshFunc,
					)
					refreshCount++

					// Small delay to prevent CPU saturation.
					time.Sleep(10 * time.Millisecond)
				}
			}
		}(i)
	}

	// Wait for workers to complete.
	wg.Wait()

	// Allow cleanup to run.
	time.Sleep(2 * config.CleanupInterval)

	// Force garbage collection and check memory.
	runtime.GC()
	runtime.GC()
	var finalMem runtime.MemStats
	runtime.ReadMemStats(&finalMem)

	// HeapAlloc is unsigned: subtract only when the heap grew, otherwise the
	// difference would wrap around. A smaller heap counts as zero growth.
	var memGrowthMB float64
	if finalMem.HeapAlloc >= initialMem.HeapAlloc {
		memGrowthMB = float64(finalMem.HeapAlloc-initialMem.HeapAlloc) / (1024 * 1024)
	}

	// Log memory statistics for debugging.
	t.Logf("Initial memory: %.2f MB", float64(initialMem.HeapAlloc)/(1024*1024))
	t.Logf("Final memory: %.2f MB", float64(finalMem.HeapAlloc)/(1024*1024))
	t.Logf("Memory growth: %.2f MB", memGrowthMB)

	// Check for excessive memory growth (threshold: 50MB).
	if memGrowthMB > 50 {
		t.Errorf("Excessive memory growth detected: %.2f MB", memGrowthMB)
	}

	// Verify no lingering operations. A missing or differently typed metric
	// is reported instead of silently skipping the check.
	inflight, ok := coordinator.GetMetrics()["current_inflight"].(int32)
	if !ok {
		t.Error("current_inflight metric is missing or is not an int32")
	} else if inflight != 0 {
		t.Errorf("Expected 0 in-flight operations after completion, got %d", inflight)
	}

	// Verify cleanup is working.
	coordinator.attemptsMutex.RLock()
	sessionCount := len(coordinator.sessionRefreshAttempts)
	coordinator.attemptsMutex.RUnlock()

	// Should have cleaned up old sessions (only recent ones remain).
	if sessionCount > numWorkers*2 {
		t.Errorf("Session cleanup not working properly, %d sessions remain", sessionCount)
	}
}
// TestRefreshTimeoutHandling verifies that refresh operations timeout properly.
//
// RefreshTimeout is set to 100ms while the mock refresh sleeps for a full
// second. The call must fail, must return within 200ms, and must carry the
// exact timeout error message.
func TestRefreshTimeoutHandling(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.RefreshTimeout = 100 * time.Millisecond

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Refresh that takes far longer than the configured timeout.
	hangingRefresh := func() (*TokenResponse, error) {
		time.Sleep(1 * time.Second)
		return &TokenResponse{AccessToken: "token"}, nil
	}

	wantMsg := fmt.Sprintf("refresh operation timed out after %v", config.RefreshTimeout)

	began := time.Now()
	_, err := coordinator.CoordinateRefresh(context.Background(), "session", "refresh_token", hangingRefresh)
	took := time.Since(began)

	// A timeout must surface as an error.
	if err == nil {
		t.Error("Expected timeout error, got nil")
	}

	// The caller must be released close to the timeout, not after the full
	// one-second refresh.
	if took > 200*time.Millisecond {
		t.Errorf("Timeout took too long: %v", took)
	}

	// When an error was returned it must be the timeout error specifically.
	if err != nil {
		if got := err.Error(); got != wantMsg {
			t.Errorf("Unexpected error message: %v", err)
		}
	}
}
// TestConcurrentDifferentTokens verifies that refreshes for different tokens
// proceed independently without blocking each other.
//
// Ten goroutines each refresh their own session/token pair. Every call must
// succeed with the access token produced by its own refresh function, every
// refresh function must have run, and the "deduplicated_requests" metric must
// stay at zero.
func TestConcurrentDifferentTokens(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	numTokens := 10
	var wg sync.WaitGroup
	wg.Add(numTokens)

	// Records which refresh functions ran; guarded by executionMutex because
	// the refresh functions run concurrently.
	executionOrder := make([]int, 0, numTokens)
	var executionMutex sync.Mutex

	for i := 0; i < numTokens; i++ {
		go func(tokenID int) {
			defer wg.Done()

			refreshFunc := func() (*TokenResponse, error) {
				executionMutex.Lock()
				executionOrder = append(executionOrder, tokenID)
				executionMutex.Unlock()

				// Varying processing times.
				time.Sleep(time.Duration(tokenID*10) * time.Millisecond)

				return &TokenResponse{
					AccessToken:  fmt.Sprintf("token_%d", tokenID),
					RefreshToken: fmt.Sprintf("refresh_%d", tokenID),
				}, nil
			}

			resp, err := coordinator.CoordinateRefresh(
				context.Background(),
				fmt.Sprintf("session_%d", tokenID),
				fmt.Sprintf("refresh_token_%d", tokenID),
				refreshFunc,
			)
			if err != nil {
				t.Errorf("Token %d refresh failed: %v", tokenID, err)
			}
			if resp == nil || resp.AccessToken != fmt.Sprintf("token_%d", tokenID) {
				t.Errorf("Token %d got wrong response", tokenID)
			}
		}(i)
	}

	wg.Wait()

	// Verify all tokens were processed.
	if len(executionOrder) != numTokens {
		t.Errorf("Expected %d executions, got %d", numTokens, len(executionOrder))
	}

	// Verify no deduplication occurred (all different tokens). A missing or
	// differently typed metric is a failure, not a reason to skip the check.
	deduped, ok := coordinator.GetMetrics()["deduplicated_requests"].(int64)
	if !ok {
		t.Fatal("deduplicated_requests metric is missing or is not an int64")
	}
	if deduped != 0 {
		t.Errorf("No deduplication expected for different tokens, got %d", deduped)
	}
}
// TestMaxConcurrentRefreshes verifies that the coordinator respects
// the maximum concurrent refresh limit.
//
// With MaxConcurrentRefreshes = 2, ten goroutines refresh ten different
// tokens at once while each refresh takes 100ms. At least one call must be
// rejected, and the highest number of refresh functions observed running at
// the same time must not exceed the limit.
func TestMaxConcurrentRefreshes(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.MaxConcurrentRefreshes = 2

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Both counters are accessed exclusively through sync/atomic.
	var currentConcurrent int32
	var maxConcurrent int32

	refreshFunc := func() (*TokenResponse, error) {
		current := atomic.AddInt32(&currentConcurrent, 1)

		// Raise the recorded peak to current unless another goroutine has
		// already recorded an equal or higher value; retry on CAS contention.
		for {
			peak := atomic.LoadInt32(&maxConcurrent)
			if current <= peak || atomic.CompareAndSwapInt32(&maxConcurrent, peak, current) {
				break
			}
		}

		time.Sleep(100 * time.Millisecond)
		atomic.AddInt32(&currentConcurrent, -1)

		return &TokenResponse{AccessToken: "token"}, nil
	}

	numRequests := 10
	var wg sync.WaitGroup
	wg.Add(numRequests)

	errs := make([]error, 0, numRequests)
	var errsMutex sync.Mutex

	for i := 0; i < numRequests; i++ {
		go func(id int) {
			defer wg.Done()

			_, err := coordinator.CoordinateRefresh(
				context.Background(),
				fmt.Sprintf("session_%d", id),
				fmt.Sprintf("token_%d", id),
				refreshFunc,
			)
			if err != nil {
				errsMutex.Lock()
				errs = append(errs, err)
				errsMutex.Unlock()
			}
		}(i)
	}

	wg.Wait()

	// Some requests should have been rejected due to concurrency limit.
	if len(errs) == 0 {
		t.Error("Expected some requests to be rejected due to concurrency limit")
	}

	// Verify max concurrent never exceeded limit. The peak is read atomically
	// to match how it is written.
	if peak := atomic.LoadInt32(&maxConcurrent); peak > int32(config.MaxConcurrentRefreshes) {
		t.Errorf("Max concurrent refreshes (%d) exceeded limit (%d)",
			peak, config.MaxConcurrentRefreshes)
	}
}
// TestSessionWindowReset verifies that refresh attempt windows reset properly.
//
// With MaxRefreshAttempts = 2 and a 500ms attempt window, the test burns both
// attempts, expects the next call to be rejected with the cooldown error, and
// expects the session to still be rejected after the attempt window alone has
// elapsed. The last check assumes the default RefreshCooldownPeriod is longer
// than the 500ms window (the default is not visible in this test).
func TestSessionWindowReset(t *testing.T) {
	// Error text the test recognises as "session is in cooldown".
	const cooldownMsg = "refresh attempts exceeded for session, in cooldown period"
	rejectedByCooldown := func(err error) bool {
		return err != nil && err.Error() == cooldownMsg
	}

	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.MaxRefreshAttempts = 2
	config.RefreshAttemptWindow = 500 * time.Millisecond
	config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Keep the circuit breaker out of the way of the rate limiter.
	coordinator.circuitBreaker.config.MaxFailures = 10

	// Unique identifiers so other tests cannot interfere.
	sessionID := fmt.Sprintf("window_test_session_%d", time.Now().UnixNano())
	refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())

	alwaysFail := func() (*TokenResponse, error) {
		return nil, fmt.Errorf("refresh failed")
	}

	// Spend every attempt allowed in the first window; results are irrelevant.
	for left := config.MaxRefreshAttempts; left > 0; left-- {
		_, _ = coordinator.CoordinateRefresh(context.Background(), sessionID, refreshToken, alwaysFail)
	}

	// One more attempt must be answered with the cooldown error.
	ctx := context.Background()
	_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, alwaysFail)
	if !rejectedByCooldown(err) {
		t.Error("Expected cooldown after max attempts")
	}

	// Let the attempt window (but not the cooldown) run out.
	time.Sleep(config.RefreshAttemptWindow + 100*time.Millisecond)

	// The cooldown must outlive the window.
	_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, alwaysFail)
	if !rejectedByCooldown(err) {
		t.Error("Should still be in cooldown period")
	}
}
// BenchmarkConcurrentRefreshDeduplication measures performance of deduplication.
//
// Parallel workers call CoordinateRefresh in a loop, each cycling through the
// same ten session/token pairs, against a refresh function that takes 10ms.
// Results and errors of the individual calls are ignored; the total and
// deduplicated request counters are logged once the timed section is over.
func BenchmarkConcurrentRefreshDeduplication(b *testing.B) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	slowRefresh := func() (*TokenResponse, error) {
		time.Sleep(10 * time.Millisecond)
		return &TokenResponse{
			AccessToken: "token",
		}, nil
	}

	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		// seq is local to each parallel worker.
		for seq := 0; pb.Next(); seq++ {
			slot := seq % 10 // Reuse 10 sessions and 10 tokens
			_, _ = coordinator.CoordinateRefresh(
				context.Background(),
				fmt.Sprintf("session_%d", slot),
				fmt.Sprintf("token_%d", slot),
				slowRefresh,
			)
		}
	})

	b.StopTimer()

	// Report metrics.
	snapshot := coordinator.GetMetrics()
	b.Logf("Total requests: %v", snapshot["total_requests"])
	b.Logf("Deduplicated: %v", snapshot["deduplicated_requests"])
}
// TestCleanupRoutine verifies that the cleanup routine removes stale entries.
//
// Five sessions are registered through recordRefreshAttempt. With a 200ms
// attempt window and a 100ms cleanup interval, the test sleeps for twice the
// window plus two cleanup intervals and then expects the tracking map to be
// empty.
func TestCleanupRoutine(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	config.CleanupInterval = 100 * time.Millisecond
	config.RefreshAttemptWindow = 200 * time.Millisecond

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// trackedSessions returns the number of per-session trackers, read under
	// the attempts read lock.
	trackedSessions := func() int {
		coordinator.attemptsMutex.RLock()
		defer coordinator.attemptsMutex.RUnlock()
		return len(coordinator.sessionRefreshAttempts)
	}

	// Register some sessions.
	for n := 0; n < 5; n++ {
		coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", n))
	}

	if before := trackedSessions(); before != 5 {
		t.Errorf("Expected 5 sessions, got %d", before)
	}

	// Wait for cleanup to run (2x window + cleanup interval).
	time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)

	if after := trackedSessions(); after != 0 {
		t.Errorf("Expected 0 sessions after cleanup, got %d", after)
	}
}
+159
View File
@@ -0,0 +1,159 @@
package traefikoidc
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestRefreshCoordinatorRaceCondition specifically tests for race conditions
// in the refresh coordinator's concurrent operation handling.
//
// Fifty goroutines are parked on a start channel and released together so
// they all call CoordinateRefresh for the same session and token at once.
// Every call must succeed with the expected access token, the refresh
// function must have run at most three times, and the "total_requests" metric
// must equal the number of goroutines.
func TestRefreshCoordinatorRaceCondition(t *testing.T) {
	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	// Increase rate limit for this race condition test.
	config.MaxRefreshAttempts = 100 // Allow many attempts for race testing

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	// Counts actual executions of the refresh function.
	var executions int32
	refreshFunc := func() (*TokenResponse, error) {
		atomic.AddInt32(&executions, 1)
		time.Sleep(50 * time.Millisecond) // Simulate work
		return &TokenResponse{
			AccessToken:  "test_token",
			RefreshToken: "test_refresh",
		}, nil
	}

	// Launch many goroutines concurrently.
	const numGoroutines = 50
	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	ctx := context.Background()
	sessionID := "test_session"
	refreshToken := "test_refresh_token"

	// Closing startChan releases every waiting goroutine at the same time.
	startChan := make(chan struct{})

	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			defer wg.Done()

			// Wait for signal to start.
			<-startChan

			result, err := coordinator.CoordinateRefresh(
				ctx,
				sessionID,
				refreshToken,
				refreshFunc,
			)

			// Basic validation.
			if err != nil {
				t.Errorf("Goroutine %d: unexpected error: %v", id, err)
			}
			if result == nil || result.AccessToken != "test_token" {
				t.Errorf("Goroutine %d: invalid result", id)
			}
		}(i)
	}

	// Release all goroutines at once.
	close(startChan)

	// Wait for completion.
	wg.Wait()

	// Check that deduplication worked.
	actualExecutions := atomic.LoadInt32(&executions)
	t.Logf("Executions: %d out of %d goroutines", actualExecutions, numGoroutines)

	// With proper deduplication, we should have very few executions.
	// Allow for some timing slack - up to 3 executions is acceptable.
	if actualExecutions > 3 {
		t.Errorf("Too many refresh executions: %d (expected <= 3)", actualExecutions)
	}

	// Verify metrics. A missing or differently typed metric is a failure, not
	// a reason to skip the check.
	total, ok := coordinator.GetMetrics()["total_requests"].(int64)
	if !ok {
		t.Fatal("total_requests metric is missing or is not an int64")
	}
	if total != int64(numGoroutines) {
		t.Errorf("Expected %d total requests, got %d", numGoroutines, total)
	}
}
// TestRefreshCoordinatorNoRaceWithDifferentTokens verifies no interference
// between different refresh tokens.
//
// Ten distinct session/token pairs are each refreshed by five goroutines.
// The only error tolerated is the "maximum concurrent refresh operations
// reached" rejection. The refresh function must have run at most twice per
// token in total.
func TestRefreshCoordinatorNoRaceWithDifferentTokens(t *testing.T) {
	// Rejection that is expected under load and therefore not a failure.
	const concurrencyLimitMsg = "maximum concurrent refresh operations reached"

	const numTokens = 10
	const goroutinesPerToken = 5

	logger := GetSingletonNoOpLogger()
	config := DefaultRefreshCoordinatorConfig()
	// Increase concurrent limit to handle 10 different tokens.
	config.MaxConcurrentRefreshes = 15
	config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
	// Increase rate limit since we have 5 goroutines per token.
	config.MaxRefreshAttempts = 10 // Allow multiple attempts per session

	coordinator := NewRefreshCoordinator(config, logger)
	defer coordinator.Shutdown()

	var totalExecutions int32
	var wg sync.WaitGroup
	wg.Add(numTokens * goroutinesPerToken)

	countingRefresh := func() (*TokenResponse, error) {
		atomic.AddInt32(&totalExecutions, 1)
		time.Sleep(10 * time.Millisecond)
		return &TokenResponse{
			AccessToken: "token",
		}, nil
	}

	// attempt performs one coordinated refresh and reports any error other
	// than the concurrency-limit rejection.
	attempt := func(tid, gid int, sessionID, refreshToken string) {
		defer wg.Done()

		_, err := coordinator.CoordinateRefresh(
			context.Background(),
			sessionID,
			refreshToken,
			countingRefresh,
		)
		if err == nil || err.Error() == concurrencyLimitMsg {
			return
		}
		t.Errorf("Token %d, Goroutine %d: unexpected error: %v", tid, gid, err)
	}

	// baseID makes the identifiers unique to this test run.
	baseID := time.Now().UnixNano()
	for tokenID := 0; tokenID < numTokens; tokenID++ {
		sessionID := fmt.Sprintf("session_%d_%d", baseID, tokenID)
		refreshToken := fmt.Sprintf("refresh_%d_%d", baseID, tokenID)

		for g := 0; g < goroutinesPerToken; g++ {
			go attempt(tokenID, g, sessionID, refreshToken)
		}
	}

	wg.Wait()

	// Each token should have had ~1 execution (maybe 2 due to timing).
	actualExecutions := atomic.LoadInt32(&totalExecutions)
	t.Logf("Total executions: %d for %d different tokens", actualExecutions, numTokens)

	// Should be close to numTokens (one per unique token).
	upperBound := int32(numTokens * 2)
	if actualExecutions > upperBound {
		t.Errorf("Too many executions: %d (expected ~%d)", actualExecutions, numTokens)
	}
}
+5
View File
@@ -34,6 +34,11 @@ var (
globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions
)
// ResetGlobalSessionCounters resets global session tracking for testing.
//
// It atomically stores 0 into globalSessionCount. The configured limit,
// globalMaxSessions, is not modified. Nothing here checks whether sessions
// are still alive, so the counter no longer reflects them after the reset.
func ResetGlobalSessionCounters() {
atomic.StoreInt64(&globalSessionCount, 0)
}
// Predefined configurations for each token type
var (
AccessTokenConfig = TokenConfig{
+5
View File
@@ -33,6 +33,11 @@ var (
globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions
)
// ResetGlobalSessionCounters resets global session tracking for testing.
//
// It atomically stores 0 into globalSessionCount. The configured limit,
// globalMaxSessions, is not modified. Nothing here checks whether sessions
// are still alive, so the counter no longer reflects them after the reset.
func ResetGlobalSessionCounters() {
atomic.StoreInt64(&globalSessionCount, 0)
}
// Predefined configurations for each token type
var (
AccessTokenConfig = TokenConfig{
+17
View File
@@ -11,6 +11,8 @@ import (
"sync/atomic"
"testing"
"time"
"github.com/lukaszraczylo/traefikoidc/session/chunking"
)
// GlobalTestCleanup tracks and cleans up test resources
@@ -113,6 +115,15 @@ func (g *GlobalTestCleanup) CleanupAll() {
// Reset all global singletons to prevent state pollution between tests
ResetGlobalMemoryMonitor()
ResetGlobalTaskRegistry()
ResetGlobalMemoryOptimizations()
ResetSingletonNoOpLogger()
// Reset global session counters to prevent overflow in memory calculations
ResetGlobalSessionCounters()
// Reset global session counters in chunking package as well
// Note: This calls the function in session/chunking package
resetChunkingGlobalSessionCounters()
// Give background tasks time to finish cleanup
time.Sleep(100 * time.Millisecond)
@@ -949,3 +960,9 @@ func (h *PerformanceTestHelper) Reset() {
defer h.mu.Unlock()
h.samples = h.samples[:0]
}
// resetChunkingGlobalSessionCounters resets the global session counters
// in the chunking package to prevent test interference.
//
// It only forwards to chunking.ResetGlobalSessionCounters and has no logic
// of its own.
func resetChunkingGlobalSessionCounters() {
chunking.ResetGlobalSessionCounters()
}