diff --git a/issue67_regression_test.go b/issue67_regression_test.go new file mode 100644 index 0000000..7e58d57 --- /dev/null +++ b/issue67_regression_test.go @@ -0,0 +1,541 @@ +package traefikoidc + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestIssue67_InfiniteRefreshLoop reproduces and verifies the fix for issue #67 +// where concurrent requests with expired tokens caused an infinite refresh loop +// leading to OOM conditions +func TestIssue67_InfiniteRefreshLoop(t *testing.T) { + // Track memory at start + runtime.GC() + var startMem runtime.MemStats + runtime.ReadMemStats(&startMem) + + // Create a mock authorization server + var refreshAttempts int32 + var concurrentRefreshes int32 + var maxConcurrent int32 + + // Create a handler with server URL to be set after creation + var serverURL string + + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + // Track concurrent refresh attempts + current := atomic.AddInt32(&concurrentRefreshes, 1) + defer atomic.AddInt32(&concurrentRefreshes, -1) + + // Update max concurrent + for { + max := atomic.LoadInt32(&maxConcurrent) + if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { + break + } + } + + attempts := atomic.AddInt32(&refreshAttempts, 1) + + // Simulate slow/failing token endpoint (like in the issue) + if attempts < 5 { + // First few attempts fail to trigger retries + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte(`{"error": "temporarily_unavailable"}`)) + } else { + // Eventually succeed + time.Sleep(50 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "id_token": "new_id_token", + "expires_in": 3600, + "token_type": "Bearer" 
+ }`)) + } + + case "/.well-known/openid-configuration": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{ + "issuer": "%s", + "authorization_endpoint": "%s/authorize", + "token_endpoint": "%s/token", + "jwks_uri": "%s/keys", + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "scopes_supported": ["openid", "profile", "email"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + "claims_supported": ["sub", "name", "email"] + }`, serverURL, serverURL, serverURL, serverURL))) + + case "/keys": + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "keys": [{ + "kty": "RSA", + "use": "sig", + "kid": "test-key", + "n": "test", + "e": "AQAB" + }] + }`)) + } + })) + defer authServer.Close() + + // Set the server URL after creation + serverURL = authServer.URL + + // Setup TraefikOIDC with refresh coordinator + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.MaxRefreshAttempts = 3 + config.RefreshAttemptWindow = 1 * time.Second + config.MaxConcurrentRefreshes = 2 + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Simulate expired session + expiredSession := &MockExpiredSession{ + refreshToken: "test_refresh_token", + sessionID: "test_session", + isExpired: true, + } + + // Simulate multiple concurrent requests (as reported in issue) + numConcurrentRequests := 50 + var wg sync.WaitGroup + wg.Add(numConcurrentRequests) + + // Track results + var successCount int32 + var errorCount int32 + errors := make([]error, 0, numConcurrentRequests) + var errorMutex sync.Mutex + + // Launch concurrent requests with expired tokens + startTime := time.Now() + timeout := 5 * time.Second + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for i := 0; i < numConcurrentRequests; i++ { + go 
func(reqID int) { + defer wg.Done() + + // Each request tries to refresh the expired token + refreshFunc := func() (*TokenResponse, error) { + // Simulate calling the token endpoint + resp, err := http.Post( + serverURL+"/token", + "application/x-www-form-urlencoded", + nil, + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %d", resp.StatusCode) + } + + return &TokenResponse{ + AccessToken: fmt.Sprintf("new_access_%d", reqID), + RefreshToken: "new_refresh", + IDToken: "new_id", + ExpiresIn: 3600, + }, nil + } + + // Use coordinator to prevent infinite loop + result, err := coordinator.CoordinateRefresh( + ctx, + expiredSession.sessionID, + expiredSession.refreshToken, + refreshFunc, + ) + + if err != nil { + atomic.AddInt32(&errorCount, 1) + errorMutex.Lock() + errors = append(errors, err) + errorMutex.Unlock() + } else if result != nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + // Wait for completion or timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Completed normally + case <-ctx.Done(): + t.Fatal("Test timed out - possible infinite loop detected!") + } + + elapsed := time.Since(startTime) + + // Verify no infinite loop occurred + if elapsed > timeout { + t.Fatalf("Requests took too long: %v (possible infinite loop)", elapsed) + } + + // Check memory usage + runtime.GC() + var endMem runtime.MemStats + runtime.ReadMemStats(&endMem) + + // Calculate memory growth safely to prevent underflow + var memGrowthMB float64 + if endMem.HeapAlloc >= startMem.HeapAlloc { + memGrowthMB = float64(endMem.HeapAlloc-startMem.HeapAlloc) / (1024 * 1024) + } else { + // Memory decreased (GC occurred), treat as 0 growth + memGrowthMB = 0 + } + t.Logf("Memory stats: start=%d bytes, end=%d bytes, growth=%.2f MB", + startMem.HeapAlloc, endMem.HeapAlloc, memGrowthMB) + + // Memory should not grow 
// issue67HeapGrowthMB returns how many MiB the heap grew between two
// runtime.MemStats.HeapAlloc samples. HeapAlloc is unsigned, so a heap that
// shrank (for example because a GC ran in between) is reported as 0 instead of
// wrapping around to a huge positive number.
func issue67HeapGrowthMB(startBytes, endBytes uint64) float64 {
	if endBytes <= startBytes {
		return 0
	}
	return float64(endBytes-startBytes) / (1024 * 1024)
}

// TestIssue67_WithoutCoordinator demonstrates the issue without the fix.
// WARNING: This test may consume significant memory - skip in CI.
//
// It starts numRequests goroutines that each perform up to three simulated
// refresh attempts with no deduplication, then logs the number of attempts,
// the peak concurrency and the heap growth. It is skipped in short mode and
// unless the test binary runs verbosely (-v).
func TestIssue67_WithoutCoordinator(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping memory-intensive test in short mode")
	}

	// Only run this test with explicit flag to demonstrate the issue.
	if !testing.Verbose() {
		t.Skip("Skipping demonstration of issue without fix (run with -v to see)")
	}

	// Track memory at start.
	runtime.GC()
	var startMem runtime.MemStats
	runtime.ReadMemStats(&startMem)

	var refreshAttempts int32
	var maxConcurrent int32
	var currentConcurrent int32

	// Simulate the issue: multiple goroutines attempting refresh without coordination.
	numRequests := 100
	var wg sync.WaitGroup
	wg.Add(numRequests)

	// Use a context with short timeout to prevent actual OOM.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < numRequests; i++ {
		go func(id int) {
			defer wg.Done()

			// Simulate retry logic without deduplication (the bug): every
			// attempt is treated as a failure, so each goroutine runs all
			// three attempts unless the context expires first.
			for attempt := 0; attempt < 3; attempt++ {
				select {
				case <-ctx.Done():
					return
				default:
				}

				current := atomic.AddInt32(&currentConcurrent, 1)

				// Raise the recorded peak with a CAS loop so concurrent
				// updates never lower it.
				for {
					peak := atomic.LoadInt32(&maxConcurrent)
					if current <= peak || atomic.CompareAndSwapInt32(&maxConcurrent, peak, current) {
						break
					}
				}

				atomic.AddInt32(&refreshAttempts, 1)

				// Simulate token refresh with a linearly growing backoff
				// (0ms, 100ms, 200ms).
				time.Sleep(time.Duration(attempt*100) * time.Millisecond)

				// Allocate memory to simulate token processing.
				_ = make([]byte, 1024*10) // 10KB per attempt

				atomic.AddInt32(&currentConcurrent, -1)
			}
		}(i)
	}

	// Wait with timeout.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// Completed
	case <-ctx.Done():
		// Timed out (expected in problematic scenario)
	}

	// Check memory usage.
	runtime.GC()
	var endMem runtime.MemStats
	runtime.ReadMemStats(&endMem)

	memGrowthMB := issue67HeapGrowthMB(startMem.HeapAlloc, endMem.HeapAlloc)

	// Workers may still be running if the context expired first, so the
	// counters must be read atomically.
	attempts := atomic.LoadInt32(&refreshAttempts)
	peakConcurrent := atomic.LoadInt32(&maxConcurrent)

	t.Logf("WITHOUT COORDINATOR:")
	t.Logf(" Refresh attempts: %d", attempts)
	t.Logf(" Max concurrent: %d", peakConcurrent)
	t.Logf(" Memory growth: %.2f MB", memGrowthMB)

	// This demonstrates the issue - high concurrency and many attempts.
	if attempts < int32(numRequests*2) {
		t.Logf("Note: Without coordinator, saw %d refresh attempts for %d requests",
			attempts, numRequests)
	}
}
// MockExpiredSession is a test double standing in for an expired session.
// It only carries the values the tests hand to the refresh coordinator.
type MockExpiredSession struct {
	// refreshToken is the token returned by GetRefreshToken.
	refreshToken string
	// sessionID is the identifier returned by GetSessionID.
	sessionID string
	// isExpired is the flag returned by IsExpired.
	isExpired bool
}

// GetRefreshToken reports the refresh token stored in the mock.
func (s *MockExpiredSession) GetRefreshToken() string { return s.refreshToken }

// GetSessionID reports the session identifier stored in the mock.
func (s *MockExpiredSession) GetSessionID() string { return s.sessionID }

// IsExpired reports whether the mock was configured as expired.
func (s *MockExpiredSession) IsExpired() bool { return s.isExpired }
Cleanup routines + + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.MaxRefreshAttempts = 5 + config.RefreshAttemptWindow = 1 * time.Second + config.RefreshCooldownPeriod = 2 * time.Second + config.MaxConcurrentRefreshes = 3 + config.CleanupInterval = 500 * time.Millisecond + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Test 1: Normal operation + t.Run("NormalOperation", func(t *testing.T) { + refreshFunc := func() (*TokenResponse, error) { + return &TokenResponse{AccessToken: "token1"}, nil + } + + ctx := context.Background() + result, err := coordinator.CoordinateRefresh(ctx, "session1", "refresh1", refreshFunc) + + if err != nil { + t.Errorf("Normal refresh failed: %v", err) + } + if result == nil || result.AccessToken != "token1" { + t.Error("Invalid result from normal refresh") + } + }) + + // Test 2: Circuit breaker activation + t.Run("CircuitBreaker", func(t *testing.T) { + failingRefresh := func() (*TokenResponse, error) { + return nil, fmt.Errorf("service unavailable") + } + + // Trigger circuit breaker + for i := 0; i < 4; i++ { + ctx := context.Background() + _, _ = coordinator.CoordinateRefresh(ctx, + fmt.Sprintf("cb_session_%d", i), "refresh_cb", failingRefresh) + } + + // Next request should be blocked by circuit breaker + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, "cb_session_blocked", "refresh_cb", failingRefresh) + + if err == nil || !strings.Contains(err.Error(), "circuit breaker") { + t.Errorf("Circuit breaker should have blocked request: %v", err) + } + }) + + // Test 3: Rate limiting + t.Run("RateLimiting", func(t *testing.T) { + // Reset circuit breaker to closed state for this test + coordinator.circuitBreaker.mutex.Lock() + atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed + atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0) + coordinator.circuitBreaker.mutex.Unlock() + + // Temporarily increase 
circuit breaker threshold to not interfere + oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures + coordinator.circuitBreaker.config.MaxFailures = 20 + defer func() { + coordinator.circuitBreaker.config.MaxFailures = oldMaxFailures + }() + + failingRefresh := func() (*TokenResponse, error) { + return nil, fmt.Errorf("failed") + } + + sessionID := "rate_limit_session" + + // Exhaust attempts + for i := 0; i < config.MaxRefreshAttempts+1; i++ { + ctx := context.Background() + _, _ = coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh) + // Add delay to ensure operations complete and aren't deduplicated + time.Sleep(150 * time.Millisecond) + } + + // Should be in cooldown + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh) + + if err == nil || !strings.Contains(err.Error(), "cooldown") { + t.Errorf("Rate limiting should have triggered cooldown: %v", err) + } + }) + + // Test 4: Cleanup + t.Run("Cleanup", func(t *testing.T) { + // Add some sessions + for i := 0; i < 5; i++ { + coordinator.recordRefreshAttempt(fmt.Sprintf("cleanup_session_%d", i)) + } + + // Wait for cleanup + time.Sleep(config.CleanupInterval * 3) + + // Old sessions should be cleaned up + coordinator.attemptsMutex.RLock() + count := len(coordinator.sessionRefreshAttempts) + coordinator.attemptsMutex.RUnlock() + + // Should have fewer sessions after cleanup + if count > 10 { + t.Errorf("Cleanup not working, %d sessions remain", count) + } + }) + + // Verify final metrics + metrics := coordinator.GetMetrics() + t.Logf("Final metrics: %+v", metrics) +} diff --git a/memory_optimizations.go b/memory_optimizations.go index edf43d9..3ecf434 100644 --- a/memory_optimizations.go +++ b/memory_optimizations.go @@ -32,6 +32,12 @@ func GetMemoryOptimizations() *MemoryOptimizations { return globalMemoryOpts } +// ResetGlobalMemoryOptimizations resets the global memory optimizations for testing +func 
ResetGlobalMemoryOptimizations() { + globalMemoryOptsOnce = sync.Once{} + globalMemoryOpts = nil +} + // BufferPool manages a pool of byte buffers type BufferPool struct { pool sync.Pool diff --git a/metadata_cache.go b/metadata_cache.go index 182152c..79e47d4 100644 --- a/metadata_cache.go +++ b/metadata_cache.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "sync" "time" ) @@ -95,7 +96,8 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st } // Fetch from provider - metadataURL := providerURL + "/.well-known/openid-configuration" + // Ensure no double slashes by trimming trailing slash from provider URL + metadataURL := strings.TrimRight(providerURL, "/") + "/.well-known/openid-configuration" mc.logger.Infof("Fetching provider metadata from: %s", metadataURL) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) diff --git a/refresh_coordinator.go b/refresh_coordinator.go new file mode 100644 index 0000000..5c31464 --- /dev/null +++ b/refresh_coordinator.go @@ -0,0 +1,596 @@ +package traefikoidc + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "sync" + "sync/atomic" + "time" +) + +// RefreshCoordinator prevents duplicate refresh token operations and manages +// refresh attempt tracking to prevent infinite loops and OOM conditions. +// It implements request coalescing, rate limiting, and circuit breaking +// specifically for token refresh operations. 
// RefreshCoordinatorConfig holds the tunables of a RefreshCoordinator.
type RefreshCoordinatorConfig struct {
	// MaxRefreshAttempts is the per-session attempt budget inside one
	// tracking window; reaching it puts the session into cooldown.
	MaxRefreshAttempts int
	// RefreshAttemptWindow is the length of the per-session tracking window.
	RefreshAttemptWindow time.Duration
	// RefreshCooldownPeriod is how long a session stays in cooldown once
	// the attempt budget is used up.
	RefreshCooldownPeriod time.Duration
	// MaxConcurrentRefreshes caps the number of refresh operations that may
	// be in flight at the same time.
	MaxConcurrentRefreshes int
	// RefreshTimeout bounds a single refresh operation.
	RefreshTimeout time.Duration
	// EnableMemoryPressureDetection switches the memory pressure check on.
	EnableMemoryPressureDetection bool
	// MemoryPressureThresholdMB is the memory pressure threshold, in MB.
	MemoryPressureThresholdMB uint64
	// CleanupInterval is the period of the stale-entry cleanup loop.
	CleanupInterval time.Duration
	// DeduplicationCleanupDelay is how long a finished refresh operation
	// stays in the deduplication map before it is removed.
	// Use 0 for immediate removal (useful for tests).
	DeduplicationCleanupDelay time.Duration
}
// refreshAttemptTracker is the per-session bookkeeping used for refresh
// rate limiting.
type refreshAttemptTracker struct {
	// attempts is the number of refresh attempts seen in the current window.
	attempts int32
	// lastAttemptTime is when the most recent attempt was recorded.
	lastAttemptTime time.Time
	// windowStartTime is when the current tracking window began.
	windowStartTime time.Time
	// inCooldown is true while the session is in cooldown.
	inCooldown bool
	// cooldownEndTime is the instant at which the cooldown ends.
	cooldownEndTime time.Time
	// consecutiveFailures counts refresh failures in a row.
	consecutiveFailures int32
}
2=half-open + failures int32 + lastFailureTime time.Time + lastSuccessTime time.Time + config RefreshCircuitBreakerConfig + mutex sync.RWMutex +} + +// RefreshCircuitBreakerConfig configures the refresh circuit breaker +type RefreshCircuitBreakerConfig struct { + MaxFailures int + OpenDuration time.Duration + HalfOpenRequests int +} + +// NewRefreshCoordinator creates a new refresh coordinator +func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *RefreshCoordinator { + if logger == nil { + logger = GetSingletonNoOpLogger() + } + + rc := &RefreshCoordinator{ + inFlightRefreshes: make(map[string]*refreshOperation), + sessionRefreshAttempts: make(map[string]*refreshAttemptTracker), + config: config, + metrics: &RefreshMetrics{}, + logger: logger, + stopChan: make(chan struct{}), + circuitBreaker: &RefreshCircuitBreaker{ + config: RefreshCircuitBreakerConfig{ + MaxFailures: 3, + OpenDuration: 30 * time.Second, + HalfOpenRequests: 1, + }, + }, + } + + // Start cleanup goroutine + rc.wg.Add(1) + go rc.cleanupRoutine() + + return rc +} + +// CoordinateRefresh ensures only one refresh operation happens per refresh token +// and implements request coalescing for concurrent refresh attempts +func (rc *RefreshCoordinator) CoordinateRefresh( + ctx context.Context, + sessionID string, + refreshToken string, + refreshFunc func() (*TokenResponse, error), +) (*TokenResponse, error) { + // Increment total request count + atomic.AddInt64(&rc.metrics.totalRefreshRequests, 1) + + // Check circuit breaker first + if !rc.circuitBreaker.AllowRequest() { + atomic.AddInt64(&rc.metrics.circuitBreakerTrips, 1) + return nil, fmt.Errorf("refresh circuit breaker is open due to repeated failures") + } + + // Create hash of refresh token for deduplication + tokenHash := rc.hashRefreshToken(refreshToken) + + // CRITICAL FIX: Atomically check for existing operation OR create new one + // This prevents the race where multiple goroutines check, find nothing, then all create + 
operation, isNew, err := rc.getOrCreateOperation(ctx, sessionID, tokenHash, refreshToken) + + if err != nil { + // Operation creation was rejected (rate limit, memory pressure, concurrent limit) + return nil, err + } + + if isNew { + // We created a new operation, so we need to execute it + go rc.executeRefreshAsync(operation, sessionID, tokenHash, refreshFunc) + } else { + // Joined existing operation - this is a deduplicated request + atomic.AddInt64(&rc.metrics.deduplicatedRequests, 1) + } + + // Wait for the operation to complete + select { + case <-operation.done: + // Get the result + operation.mutex.RLock() + result := operation.result + operation.mutex.RUnlock() + + if result != nil { + // Record metrics based on result + if result.err != nil { + rc.circuitBreaker.RecordFailure() + rc.recordRefreshFailure(sessionID) + atomic.AddInt64(&rc.metrics.failedRefreshes, 1) + } else { + rc.circuitBreaker.RecordSuccess() + rc.recordRefreshSuccess(sessionID) + atomic.AddInt64(&rc.metrics.successfulRefreshes, 1) + } + return result.tokenResponse, result.err + } + return nil, fmt.Errorf("refresh operation completed without result") + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// getOrCreateOperation atomically checks for an existing operation or creates a new one +// Returns (operation, true, nil) if a new operation was created +// Returns (operation, false, nil) if joined an existing operation +// Returns (nil, false, error) if the operation was rejected +func (rc *RefreshCoordinator) getOrCreateOperation( + ctx context.Context, + sessionID string, + tokenHash string, + refreshToken string, +) (*refreshOperation, bool, error) { + rc.refreshMutex.Lock() + defer rc.refreshMutex.Unlock() + + // Check for existing operation while holding the lock + if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists { + if existingOp.refreshToken == refreshToken { + // Join existing operation + atomic.AddInt32(&existingOp.waiterCount, 1) + return existingOp, false, 
nil + } + // Different refresh token for same hash - should not happen + return nil, false, fmt.Errorf("refresh token mismatch") + } + + // No existing operation - check if we can create a new one + // All checks happen while holding the lock to prevent races + + // Check and record refresh attempt for rate limiting + rc.recordRefreshAttempt(sessionID) + if rc.isInCooldown(sessionID) { + atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1) + return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period") + } + + // Check memory pressure + if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() { + atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1) + return nil, false, fmt.Errorf("system under memory pressure, refresh denied") + } + + // Check and reserve concurrent refresh slot atomically + current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes) + if int(current) >= rc.config.MaxConcurrentRefreshes { + return nil, false, fmt.Errorf("maximum concurrent refresh operations reached") + } + + // Reserve the slot - we're still holding the lock so this is safe + atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1) + + // Create and register new operation + operation := &refreshOperation{ + refreshToken: refreshToken, + done: make(chan struct{}), + startTime: time.Now(), + waiterCount: 1, + } + rc.inFlightRefreshes[tokenHash] = operation + + return operation, true, nil +} + +// executeRefreshAsync performs the actual refresh operation asynchronously +func (rc *RefreshCoordinator) executeRefreshAsync( + operation *refreshOperation, + sessionID string, + tokenHash string, + refreshFunc func() (*TokenResponse, error), +) { + defer func() { + // Signal completion to all waiters + close(operation.done) + + // Clean up operation after a configurable delay to allow waiters to read result + go func() { + if rc.config.DeduplicationCleanupDelay > 0 { + time.Sleep(rc.config.DeduplicationCleanupDelay) + } + 
rc.refreshMutex.Lock() + delete(rc.inFlightRefreshes, tokenHash) + rc.refreshMutex.Unlock() + atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1) + }() + }() + + // Create timeout context + refreshCtx, cancel := context.WithTimeout(context.Background(), rc.config.RefreshTimeout) + defer cancel() + + // Execute refresh in goroutine to respect timeout + resultChan := make(chan struct { + resp *TokenResponse + err error + }, 1) + + go func() { + resp, err := refreshFunc() + select { + case resultChan <- struct { + resp *TokenResponse + err error + }{resp, err}: + case <-refreshCtx.Done(): + } + }() + + select { + case result := <-resultChan: + // Store result for all waiters + operation.mutex.Lock() + operation.result = &refreshResult{ + tokenResponse: result.resp, + err: result.err, + fromCache: false, + } + operation.mutex.Unlock() + case <-refreshCtx.Done(): + // Timeout occurred + timeoutErr := fmt.Errorf("refresh operation timed out after %v", rc.config.RefreshTimeout) + operation.mutex.Lock() + operation.result = &refreshResult{ + tokenResponse: nil, + err: timeoutErr, + fromCache: false, + } + operation.mutex.Unlock() + } +} + +// isInCooldown checks if a session is in cooldown after recording an attempt +func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool { + rc.attemptsMutex.Lock() + defer rc.attemptsMutex.Unlock() + + tracker, exists := rc.sessionRefreshAttempts[sessionID] + if !exists { + return false // No tracker means first attempt, not in cooldown + } + + now := time.Now() + + // Check if already in cooldown + if tracker.inCooldown { + if now.After(tracker.cooldownEndTime) { + // Cooldown expired, reset tracker + tracker.inCooldown = false + tracker.attempts = 1 // Already recorded one attempt + tracker.consecutiveFailures = 0 + tracker.windowStartTime = now + return false + } + return true // Still in cooldown + } + + // Check if window expired + if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow { + // Reset 
window + tracker.attempts = 1 // Already recorded one attempt + tracker.windowStartTime = now + return false + } + + // Check if just exceeded attempt limit + if int(tracker.attempts) >= rc.config.MaxRefreshAttempts { + // Enter cooldown now + tracker.inCooldown = true + tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod) + rc.logger.Infof("Session %s entering refresh cooldown after %d attempts", + sessionID, tracker.attempts) + return true + } + + return false +} + +// recordRefreshAttempt records a refresh attempt for rate limiting +func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) { + rc.attemptsMutex.Lock() + defer rc.attemptsMutex.Unlock() + + tracker, exists := rc.sessionRefreshAttempts[sessionID] + if !exists { + tracker = &refreshAttemptTracker{ + windowStartTime: time.Now(), + } + rc.sessionRefreshAttempts[sessionID] = tracker + } + + atomic.AddInt32(&tracker.attempts, 1) + tracker.lastAttemptTime = time.Now() +} + +// recordRefreshSuccess records a successful refresh +func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) { + rc.attemptsMutex.Lock() + defer rc.attemptsMutex.Unlock() + + if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists { + tracker.consecutiveFailures = 0 + } +} + +// recordRefreshFailure records a failed refresh +func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) { + rc.attemptsMutex.Lock() + defer rc.attemptsMutex.Unlock() + + if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists { + atomic.AddInt32(&tracker.consecutiveFailures, 1) + } +} + +// hashRefreshToken creates a hash of the refresh token for deduplication +func (rc *RefreshCoordinator) hashRefreshToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} + +// isUnderMemoryPressure checks if the system is under memory pressure +func (rc *RefreshCoordinator) isUnderMemoryPressure() bool { + // This is a simplified check - in production 
you'd want to use runtime.MemStats + // or system-specific memory monitoring + return false // Placeholder - implement actual memory check +} + +// cleanupRoutine periodically cleans up stale tracking entries +func (rc *RefreshCoordinator) cleanupRoutine() { + defer rc.wg.Done() + + ticker := time.NewTicker(rc.config.CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + rc.cleanupStaleEntries() + case <-rc.stopChan: + return + } + } +} + +// cleanupStaleEntries removes outdated tracking entries +func (rc *RefreshCoordinator) cleanupStaleEntries() { + now := time.Now() + + rc.attemptsMutex.Lock() + defer rc.attemptsMutex.Unlock() + + // Clean up old session trackers + for sessionID, tracker := range rc.sessionRefreshAttempts { + // Remove trackers that haven't been used recently + if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow { + delete(rc.sessionRefreshAttempts, sessionID) + } + } +} + +// GetMetrics returns current coordinator metrics +func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} { + return map[string]interface{}{ + "total_requests": atomic.LoadInt64(&rc.metrics.totalRefreshRequests), + "deduplicated_requests": atomic.LoadInt64(&rc.metrics.deduplicatedRequests), + "successful_refreshes": atomic.LoadInt64(&rc.metrics.successfulRefreshes), + "failed_refreshes": atomic.LoadInt64(&rc.metrics.failedRefreshes), + "circuit_breaker_trips": atomic.LoadInt64(&rc.metrics.circuitBreakerTrips), + "memory_pressure_events": atomic.LoadInt64(&rc.metrics.memoryPressureEvents), + "cooldowns_triggered": atomic.LoadInt64(&rc.metrics.cooldownsTriggered), + "current_inflight": atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes), + "circuit_breaker_state": rc.circuitBreaker.GetState(), + } +} + +// Shutdown gracefully shuts down the coordinator +func (rc *RefreshCoordinator) Shutdown() { + close(rc.stopChan) + rc.wg.Wait() +} + +// AllowRequest checks if the circuit breaker allows a request +func (cb 
*RefreshCircuitBreaker) AllowRequest() bool { + cb.mutex.RLock() + defer cb.mutex.RUnlock() + + state := atomic.LoadInt32(&cb.state) + + switch state { + case 0: // Closed + return true + case 1: // Open + if time.Since(cb.lastFailureTime) > cb.config.OpenDuration { + // Try to transition to half-open + if atomic.CompareAndSwapInt32(&cb.state, 1, 2) { + return true + } + } + return false + case 2: // Half-open + return true + default: + return false + } +} + +// RecordSuccess records a successful operation +func (cb *RefreshCircuitBreaker) RecordSuccess() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + state := atomic.LoadInt32(&cb.state) + if state == 2 { // Half-open + // Close the circuit + atomic.StoreInt32(&cb.state, 0) + atomic.StoreInt32(&cb.failures, 0) + } else if state == 0 { // Closed + // Reset failure count on success + atomic.StoreInt32(&cb.failures, 0) + } + cb.lastSuccessTime = time.Now() +} + +// RecordFailure records a failed operation +func (cb *RefreshCircuitBreaker) RecordFailure() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + failures := atomic.AddInt32(&cb.failures, 1) + cb.lastFailureTime = time.Now() + + state := atomic.LoadInt32(&cb.state) + + if state == 0 && int(failures) >= cb.config.MaxFailures { + // Open the circuit + atomic.StoreInt32(&cb.state, 1) + } else if state == 2 { + // Half-open failed, return to open + atomic.StoreInt32(&cb.state, 1) + } +} + +// GetState returns the current state of the circuit breaker +func (cb *RefreshCircuitBreaker) GetState() string { + state := atomic.LoadInt32(&cb.state) + switch state { + case 0: + return "closed" + case 1: + return "open" + case 2: + return "half-open" + default: + return "unknown" + } +} diff --git a/refresh_coordinator_test.go b/refresh_coordinator_test.go new file mode 100644 index 0000000..78065f6 --- /dev/null +++ b/refresh_coordinator_test.go @@ -0,0 +1,669 @@ +package traefikoidc + +import ( + "context" + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" 
+) + +// TestConcurrentRefreshDeduplication verifies that concurrent refresh attempts +// for the same token are deduplicated and only one refresh operation occurs +func TestConcurrentRefreshDeduplication(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + // Keep default delay for this test - it's testing deduplication behavior + // Disable rate limiting for this test since we're testing deduplication + config.MaxRefreshAttempts = 1000 // High enough to not interfere + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Counter to track actual refresh executions + var refreshExecutions int32 + + // Mock refresh function + refreshFunc := func() (*TokenResponse, error) { + atomic.AddInt32(&refreshExecutions, 1) + // Simulate some processing time + time.Sleep(100 * time.Millisecond) + return &TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + IDToken: "new_id_token", + ExpiresIn: 3600, + }, nil + } + + // Number of concurrent requests + numRequests := 100 + var wg sync.WaitGroup + wg.Add(numRequests) + + // Channel to collect results + results := make(chan *TokenResponse, numRequests) + errors := make(chan error, numRequests) + + // Launch concurrent refresh attempts with unique identifiers + refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano()) + sessionID := fmt.Sprintf("test_session_%d", time.Now().UnixNano()) + + for i := 0; i < numRequests; i++ { + go func(reqID int) { + defer wg.Done() + + ctx := context.Background() + resp, err := coordinator.CoordinateRefresh( + ctx, + sessionID, + refreshToken, + refreshFunc, + ) + + if err != nil { + errors <- err + } else { + results <- resp + } + }(i) + } + + // Wait for all goroutines to complete + wg.Wait() + close(results) + close(errors) + + // Verify results + actualExecutions := atomic.LoadInt32(&refreshExecutions) + // Allow for slight timing variations - up to 2 
executions is acceptable + // This can happen when a second goroutine starts just as the first completes + if actualExecutions > 2 { + t.Errorf("Expected 1-2 refresh executions, got %d", actualExecutions) + } + + // Verify all requests got the same result + var firstResponse *TokenResponse + responseCount := 0 + + for resp := range results { + responseCount++ + if firstResponse == nil { + firstResponse = resp + } else { + // All responses should be identical (same pointer) + if resp.AccessToken != firstResponse.AccessToken { + t.Error("Different responses returned for concurrent requests") + } + } + } + + // Check for errors + errorCount := 0 + for range errors { + errorCount++ + } + + if errorCount > 0 { + t.Errorf("Unexpected errors in concurrent requests: %d", errorCount) + } + + if responseCount != numRequests { + t.Errorf("Expected %d successful responses, got %d", numRequests, responseCount) + } + + // Verify metrics + metrics := coordinator.GetMetrics() + if deduped, ok := metrics["deduplicated_requests"].(int64); ok { + // Allow for slight timing variations - at least 98 out of 100 should be deduplicated + if deduped < int64(numRequests-2) { + t.Errorf("Expected at least %d deduplicated requests, got %d", numRequests-2, deduped) + } + } +} + +// TestRefreshRateLimiting verifies that refresh attempts are rate-limited per session +func TestRefreshRateLimiting(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.MaxRefreshAttempts = 3 + config.RefreshAttemptWindow = 1 * time.Second + config.RefreshCooldownPeriod = 2 * time.Second + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Set circuit breaker to not interfere with rate limiting test + // We want to test rate limiting, not circuit breaker + coordinator.circuitBreaker.config.MaxFailures = 10 + + sessionID := "rate_limited_session" + refreshToken := "test_refresh_token" + + // Mock refresh function that 
always fails + refreshFunc := func() (*TokenResponse, error) { + return nil, fmt.Errorf("refresh failed") + } + + // Attempt refreshes beyond the limit + var attempts int + var cooldownTriggered bool + + for i := 0; i < 5; i++ { + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + + if err != nil { + if err.Error() == "refresh attempts exceeded for session, in cooldown period" { + cooldownTriggered = true + break + } + } + attempts++ + // Add delay to ensure operations complete and aren't deduplicated + time.Sleep(150 * time.Millisecond) + } + + // Verify that cooldown was triggered after max attempts + // With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts + expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1 + if attempts != expectedSuccessfulAttempts { + t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts) + } + + if !cooldownTriggered { + t.Error("Cooldown was not triggered after max attempts") + } + + // Verify that requests are blocked during cooldown + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" { + t.Error("Request should be blocked during cooldown period") + } + + // Wait for cooldown to expire + time.Sleep(config.RefreshCooldownPeriod + 100*time.Millisecond) + + // Verify that requests are allowed after cooldown + _, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + if err != nil && err.Error() == "refresh attempts exceeded for session, in cooldown period" { + t.Error("Request should be allowed after cooldown period") + } +} + +// TestCircuitBreakerProtection verifies that the circuit breaker prevents +// cascading failures during repeated refresh failures +func TestCircuitBreakerProtection(t *testing.T) { 
+ logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Set circuit breaker to trip after 3 failures + coordinator.circuitBreaker.config.MaxFailures = 3 + coordinator.circuitBreaker.config.OpenDuration = 1 * time.Second + + // Mock refresh function that always fails + refreshFunc := func() (*TokenResponse, error) { + return nil, fmt.Errorf("service unavailable") + } + + // Cause circuit breaker to trip + var tripCount int + for i := 0; i < 5; i++ { + ctx := context.Background() + _, err := coordinator.CoordinateRefresh( + ctx, + fmt.Sprintf("session_%d", i), // Different sessions + "refresh_token", + refreshFunc, + ) + + if err != nil && err.Error() == "refresh circuit breaker is open due to repeated failures" { + tripCount++ + } + } + + // Verify circuit breaker tripped + if tripCount == 0 { + t.Error("Circuit breaker did not trip after repeated failures") + } + + // Verify circuit breaker state + if coordinator.circuitBreaker.GetState() != "open" { + t.Errorf("Expected circuit breaker state 'open', got '%s'", coordinator.circuitBreaker.GetState()) + } + + // Wait for circuit to transition to half-open + time.Sleep(coordinator.circuitBreaker.config.OpenDuration + 100*time.Millisecond) + + // Mock successful refresh + successfulRefreshFunc := func() (*TokenResponse, error) { + return &TokenResponse{ + AccessToken: "new_token", + }, nil + } + + // Verify circuit allows request in half-open state + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, "session_recovery", "refresh_token", successfulRefreshFunc) + if err != nil { + t.Errorf("Circuit breaker should allow request in half-open state: %v", err) + } + + // Verify circuit closed after success + if coordinator.circuitBreaker.GetState() != "closed" { + t.Errorf("Expected circuit breaker state 'closed' after successful request, got '%s'", + 
coordinator.circuitBreaker.GetState()) + } +} + +// TestMemoryLeakPrevention verifies that the coordinator doesn't leak memory +// during sustained concurrent refresh operations +func TestMemoryLeakPrevention(t *testing.T) { + if testing.Short() { + t.Skip("Skipping memory leak test in short mode") + } + + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.CleanupInterval = 100 * time.Millisecond + config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Force garbage collection and record initial memory + runtime.GC() + runtime.GC() + var initialMem runtime.MemStats + runtime.ReadMemStats(&initialMem) + + // Run sustained concurrent operations + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var wg sync.WaitGroup + numWorkers := 10 + wg.Add(numWorkers) + + // Each worker continuously attempts refreshes + for i := 0; i < numWorkers; i++ { + go func(workerID int) { + defer wg.Done() + + refreshCount := 0 + refreshFunc := func() (*TokenResponse, error) { + // Simulate varying response times + time.Sleep(time.Duration(workerID*10) * time.Millisecond) + return &TokenResponse{ + AccessToken: fmt.Sprintf("token_%d_%d", workerID, refreshCount), + RefreshToken: fmt.Sprintf("refresh_%d_%d", workerID, refreshCount), + }, nil + } + + for { + select { + case <-ctx.Done(): + return + default: + sessionID := fmt.Sprintf("session_%d", workerID) + refreshToken := fmt.Sprintf("refresh_%d_%d", workerID, refreshCount) + + _, _ = coordinator.CoordinateRefresh( + context.Background(), + sessionID, + refreshToken, + refreshFunc, + ) + + refreshCount++ + // Small delay to prevent CPU saturation + time.Sleep(10 * time.Millisecond) + } + } + }(i) + } + + // Wait for workers to complete + wg.Wait() + + // Allow cleanup to run + time.Sleep(2 * config.CleanupInterval) + + // Force 
garbage collection and check memory + runtime.GC() + runtime.GC() + var finalMem runtime.MemStats + runtime.ReadMemStats(&finalMem) + + // Calculate memory growth safely to prevent underflow + var memGrowthMB float64 + if finalMem.HeapAlloc >= initialMem.HeapAlloc { + memGrowthMB = float64(finalMem.HeapAlloc-initialMem.HeapAlloc) / (1024 * 1024) + } else { + // Memory decreased (GC occurred), treat as 0 growth + memGrowthMB = 0 + } + + // Log memory statistics for debugging + t.Logf("Initial memory: %.2f MB", float64(initialMem.HeapAlloc)/(1024*1024)) + t.Logf("Final memory: %.2f MB", float64(finalMem.HeapAlloc)/(1024*1024)) + t.Logf("Memory growth: %.2f MB", memGrowthMB) + + // Check for excessive memory growth (threshold: 50MB) + if memGrowthMB > 50 { + t.Errorf("Excessive memory growth detected: %.2f MB", memGrowthMB) + } + + // Verify no lingering operations + metrics := coordinator.GetMetrics() + if inflight, ok := metrics["current_inflight"].(int32); ok { + if inflight != 0 { + t.Errorf("Expected 0 in-flight operations after completion, got %d", inflight) + } + } + + // Verify cleanup is working + coordinator.attemptsMutex.RLock() + sessionCount := len(coordinator.sessionRefreshAttempts) + coordinator.attemptsMutex.RUnlock() + + // Should have cleaned up old sessions (only recent ones remain) + if sessionCount > numWorkers*2 { + t.Errorf("Session cleanup not working properly, %d sessions remain", sessionCount) + } +} + +// TestRefreshTimeoutHandling verifies that refresh operations timeout properly +func TestRefreshTimeoutHandling(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.RefreshTimeout = 100 * time.Millisecond + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Mock refresh function that hangs + refreshFunc := func() (*TokenResponse, error) { + time.Sleep(1 * time.Second) // Much longer than timeout + return &TokenResponse{AccessToken: "token"}, nil + } 
+ + ctx := context.Background() + start := time.Now() + + _, err := coordinator.CoordinateRefresh(ctx, "session", "refresh_token", refreshFunc) + + elapsed := time.Since(start) + + // Verify timeout occurred + if err == nil { + t.Error("Expected timeout error, got nil") + } + + // Verify it timed out within reasonable bounds + if elapsed > 200*time.Millisecond { + t.Errorf("Timeout took too long: %v", elapsed) + } + + if err != nil && err.Error() != fmt.Sprintf("refresh operation timed out after %v", config.RefreshTimeout) { + t.Errorf("Unexpected error message: %v", err) + } +} + +// TestConcurrentDifferentTokens verifies that refreshes for different tokens +// proceed independently without blocking each other +func TestConcurrentDifferentTokens(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + numTokens := 10 + var wg sync.WaitGroup + wg.Add(numTokens) + + // Track execution order + executionOrder := make([]int, 0, numTokens) + var executionMutex sync.Mutex + + for i := 0; i < numTokens; i++ { + go func(tokenID int) { + defer wg.Done() + + refreshFunc := func() (*TokenResponse, error) { + executionMutex.Lock() + executionOrder = append(executionOrder, tokenID) + executionMutex.Unlock() + + // Varying processing times + time.Sleep(time.Duration(tokenID*10) * time.Millisecond) + + return &TokenResponse{ + AccessToken: fmt.Sprintf("token_%d", tokenID), + RefreshToken: fmt.Sprintf("refresh_%d", tokenID), + }, nil + } + + ctx := context.Background() + resp, err := coordinator.CoordinateRefresh( + ctx, + fmt.Sprintf("session_%d", tokenID), + fmt.Sprintf("refresh_token_%d", tokenID), + refreshFunc, + ) + + if err != nil { + t.Errorf("Token %d refresh failed: %v", tokenID, err) + } + + if resp == nil || resp.AccessToken != fmt.Sprintf("token_%d", tokenID) { + t.Errorf("Token %d got wrong response", tokenID) + } + }(i) + } + + 
wg.Wait() + + // Verify all tokens were processed + if len(executionOrder) != numTokens { + t.Errorf("Expected %d executions, got %d", numTokens, len(executionOrder)) + } + + // Verify no deduplication occurred (all different tokens) + metrics := coordinator.GetMetrics() + if deduped, ok := metrics["deduplicated_requests"].(int64); ok { + if deduped != 0 { + t.Errorf("No deduplication expected for different tokens, got %d", deduped) + } + } +} + +// TestMaxConcurrentRefreshes verifies that the coordinator respects +// the maximum concurrent refresh limit +func TestMaxConcurrentRefreshes(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.MaxConcurrentRefreshes = 2 + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Track concurrent executions + var currentConcurrent int32 + var maxConcurrent int32 + + refreshFunc := func() (*TokenResponse, error) { + current := atomic.AddInt32(&currentConcurrent, 1) + + // Update max if needed + for { + max := atomic.LoadInt32(&maxConcurrent) + if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) { + break + } + } + + time.Sleep(100 * time.Millisecond) + atomic.AddInt32(&currentConcurrent, -1) + + return &TokenResponse{AccessToken: "token"}, nil + } + + numRequests := 10 + var wg sync.WaitGroup + wg.Add(numRequests) + + errors := make([]error, 0, numRequests) + var errorMutex sync.Mutex + + for i := 0; i < numRequests; i++ { + go func(id int) { + defer wg.Done() + + ctx := context.Background() + _, err := coordinator.CoordinateRefresh( + ctx, + fmt.Sprintf("session_%d", id), + fmt.Sprintf("token_%d", id), + refreshFunc, + ) + + if err != nil { + errorMutex.Lock() + errors = append(errors, err) + errorMutex.Unlock() + } + }(i) + } + + wg.Wait() + + // Some requests should have been rejected due to concurrency limit + if len(errors) == 0 { + t.Error("Expected some requests to be rejected due to concurrency limit") + } + + //
Verify max concurrent never exceeded limit + if maxConcurrent > int32(config.MaxConcurrentRefreshes) { + t.Errorf("Max concurrent refreshes (%d) exceeded limit (%d)", + maxConcurrent, config.MaxConcurrentRefreshes) + } +} + +// TestSessionWindowReset verifies that refresh attempt windows reset properly +func TestSessionWindowReset(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.MaxRefreshAttempts = 2 + config.RefreshAttemptWindow = 500 * time.Millisecond + config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Set circuit breaker to not interfere with rate limiting test + coordinator.circuitBreaker.config.MaxFailures = 10 + + // Use unique identifiers to prevent test interference + sessionID := fmt.Sprintf("window_test_session_%d", time.Now().UnixNano()) + refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano()) + + // Mock refresh function that always fails + refreshFunc := func() (*TokenResponse, error) { + return nil, fmt.Errorf("refresh failed") + } + + // Use up the attempts in the first window + for i := 0; i < config.MaxRefreshAttempts; i++ { + ctx := context.Background() + _, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + } + + // Next attempt should trigger cooldown + ctx := context.Background() + _, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" { + t.Error("Expected cooldown after max attempts") + } + + // Wait for window to expire (but not cooldown) + time.Sleep(config.RefreshAttemptWindow + 100*time.Millisecond) + + // Should still be in cooldown (cooldown > window) + _, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + if err == nil || err.Error() != "refresh 
attempts exceeded for session, in cooldown period" { + t.Error("Should still be in cooldown period") + } +} + +// BenchmarkConcurrentRefreshDeduplication measures performance of deduplication +func BenchmarkConcurrentRefreshDeduplication(b *testing.B) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + refreshFunc := func() (*TokenResponse, error) { + time.Sleep(10 * time.Millisecond) + return &TokenResponse{ + AccessToken: "token", + }, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + ctx := context.Background() + sessionID := fmt.Sprintf("session_%d", i%10) // Reuse 10 sessions + refreshToken := fmt.Sprintf("token_%d", i%10) // Reuse 10 tokens + _, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc) + i++ + } + }) + + b.StopTimer() + + // Report metrics + metrics := coordinator.GetMetrics() + b.Logf("Total requests: %v", metrics["total_requests"]) + b.Logf("Deduplicated: %v", metrics["deduplicated_requests"]) +} + +// TestCleanupRoutine verifies that the cleanup routine removes stale entries +func TestCleanupRoutine(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + config.CleanupInterval = 100 * time.Millisecond + config.RefreshAttemptWindow = 200 * time.Millisecond + + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Add some sessions + for i := 0; i < 5; i++ { + coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i)) + } + + // Verify sessions exist + coordinator.attemptsMutex.RLock() + initialCount := len(coordinator.sessionRefreshAttempts) + coordinator.attemptsMutex.RUnlock() + + if initialCount != 5 { + t.Errorf("Expected 5 sessions, got %d", initialCount) + } + + // Wait for cleanup to run (2x window + cleanup interval) + time.Sleep(2*config.RefreshAttemptWindow + 
2*config.CleanupInterval) + + // Verify sessions were cleaned up + coordinator.attemptsMutex.RLock() + finalCount := len(coordinator.sessionRefreshAttempts) + coordinator.attemptsMutex.RUnlock() + + if finalCount != 0 { + t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount) + } +} diff --git a/refresh_race_test.go b/refresh_race_test.go new file mode 100644 index 0000000..57f3075 --- /dev/null +++ b/refresh_race_test.go @@ -0,0 +1,159 @@ +package traefikoidc + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestRefreshCoordinatorRaceCondition specifically tests for race conditions +// in the refresh coordinator's concurrent operation handling +func TestRefreshCoordinatorRaceCondition(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + // Increase rate limit for this race condition test + config.MaxRefreshAttempts = 100 // Allow many attempts for race testing + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + // Test concurrent access to the same refresh token + var executions int32 + refreshFunc := func() (*TokenResponse, error) { + atomic.AddInt32(&executions, 1) + time.Sleep(50 * time.Millisecond) // Simulate work + return &TokenResponse{ + AccessToken: "test_token", + RefreshToken: "test_refresh", + }, nil + } + + // Launch many goroutines concurrently + const numGoroutines = 50 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + ctx := context.Background() + sessionID := "test_session" + refreshToken := "test_refresh_token" + + // Use a channel to ensure all goroutines start at the same time + startChan := make(chan struct{}) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + + // Wait for signal to start + <-startChan + + // All goroutines try to refresh at the same time + result, err := coordinator.CoordinateRefresh( + ctx, + sessionID, + refreshToken, + refreshFunc, + ) + + // Basic validation 
+ if err != nil { + t.Errorf("Goroutine %d: unexpected error: %v", id, err) + } + if result == nil || result.AccessToken != "test_token" { + t.Errorf("Goroutine %d: invalid result", id) + } + }(i) + } + + // Release all goroutines at once + close(startChan) + + // Wait for completion + wg.Wait() + + // Check that deduplication worked + actualExecutions := atomic.LoadInt32(&executions) + t.Logf("Executions: %d out of %d goroutines", actualExecutions, numGoroutines) + + // With proper deduplication, we should have very few executions + // Allow for some timing slack - up to 3 executions is acceptable + if actualExecutions > 3 { + t.Errorf("Too many refresh executions: %d (expected <= 3)", actualExecutions) + } + + // Verify metrics + metrics := coordinator.GetMetrics() + if total, ok := metrics["total_requests"].(int64); ok { + if total != int64(numGoroutines) { + t.Errorf("Expected %d total requests, got %d", numGoroutines, total) + } + } +} + +// TestRefreshCoordinatorNoRaceWithDifferentTokens verifies no interference +// between different refresh tokens +func TestRefreshCoordinatorNoRaceWithDifferentTokens(t *testing.T) { + logger := GetSingletonNoOpLogger() + config := DefaultRefreshCoordinatorConfig() + // Increase concurrent limit to handle 10 different tokens + config.MaxConcurrentRefreshes = 15 + config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior + // Increase rate limit since we have 5 goroutines per token + config.MaxRefreshAttempts = 10 // Allow multiple attempts per session + coordinator := NewRefreshCoordinator(config, logger) + defer coordinator.Shutdown() + + const numTokens = 10 + const goroutinesPerToken = 5 + + var totalExecutions int32 + var wg sync.WaitGroup + wg.Add(numTokens * goroutinesPerToken) + + refreshFunc := func() (*TokenResponse, error) { + atomic.AddInt32(&totalExecutions, 1) + time.Sleep(10 * time.Millisecond) + return &TokenResponse{ + AccessToken: "token", + }, nil + } + + // Launch goroutines 
for different tokens with unique identifiers + baseID := time.Now().UnixNano() + for tokenID := 0; tokenID < numTokens; tokenID++ { + sessionID := fmt.Sprintf("session_%d_%d", baseID, tokenID) + refreshToken := fmt.Sprintf("refresh_%d_%d", baseID, tokenID) + + for i := 0; i < goroutinesPerToken; i++ { + go func(tid, gid int) { + defer wg.Done() + + ctx := context.Background() + _, err := coordinator.CoordinateRefresh( + ctx, + sessionID, + refreshToken, + refreshFunc, + ) + + if err != nil && err.Error() != "maximum concurrent refresh operations reached" { + // Only log non-concurrent-limit errors as failures + t.Errorf("Token %d, Goroutine %d: unexpected error: %v", tid, gid, err) + } + }(tokenID, i) + } + } + + wg.Wait() + + // Each token should have had ~1 execution (maybe 2 due to timing) + actualExecutions := atomic.LoadInt32(&totalExecutions) + t.Logf("Total executions: %d for %d different tokens", actualExecutions, numTokens) + + // Should be close to numTokens (one per unique token) + if actualExecutions > numTokens*2 { + t.Errorf("Too many executions: %d (expected ~%d)", actualExecutions, numTokens) + } +} diff --git a/session/chunking/chunk_manager.go b/session/chunking/chunk_manager.go index 30f9c37..c8efe6a 100644 --- a/session/chunking/chunk_manager.go +++ b/session/chunking/chunk_manager.go @@ -34,6 +34,11 @@ var ( globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions ) +// ResetGlobalSessionCounters resets global session tracking for testing +func ResetGlobalSessionCounters() { + atomic.StoreInt64(&globalSessionCount, 0) +} + // Predefined configurations for each token type var ( AccessTokenConfig = TokenConfig{ diff --git a/session_chunk_manager.go b/session_chunk_manager.go index ba3c593..52d49bf 100644 --- a/session_chunk_manager.go +++ b/session_chunk_manager.go @@ -33,6 +33,11 @@ var ( globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions ) +// ResetGlobalSessionCounters resets global 
session tracking for testing +func ResetGlobalSessionCounters() { + atomic.StoreInt64(&globalSessionCount, 0) +} + // Predefined configurations for each token type var ( AccessTokenConfig = TokenConfig{ diff --git a/test_infrastructure.go b/test_infrastructure.go index f9bf222..58363a5 100644 --- a/test_infrastructure.go +++ b/test_infrastructure.go @@ -11,6 +11,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/lukaszraczylo/traefikoidc/session/chunking" ) // GlobalTestCleanup tracks and cleans up test resources @@ -113,6 +115,15 @@ func (g *GlobalTestCleanup) CleanupAll() { // Reset all global singletons to prevent state pollution between tests ResetGlobalMemoryMonitor() ResetGlobalTaskRegistry() + ResetGlobalMemoryOptimizations() + ResetSingletonNoOpLogger() + + // Reset global session counters to prevent overflow in memory calculations + ResetGlobalSessionCounters() + + // Reset global session counters in chunking package as well + // Note: This calls the function in session/chunking package + resetChunkingGlobalSessionCounters() // Give background tasks time to finish cleanup time.Sleep(100 * time.Millisecond) @@ -949,3 +960,9 @@ func (h *PerformanceTestHelper) Reset() { defer h.mu.Unlock() h.samples = h.samples[:0] } + +// resetChunkingGlobalSessionCounters resets the global session counters +// in the chunking package to prevent test interference +func resetChunkingGlobalSessionCounters() { + chunking.ResetGlobalSessionCounters() +}