mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
827926bc3a
Three related changes addressing post-v1.0.15 code-review findings and the user's observation that we have been "throwing maps around" — under Yaegi every sync.Map / atomic / mutex dispatch costs ~1-5ms of interpreter overhead, so the number of dispatches per request matters as much as whether they are lock-free. 1. Remove cleanupTimers map + cleanupTimerMu sync.Mutex. scheduleDelayedCleanup previously tracked every pending timer in a map guarded by a mutex so a duplicate scheduling could cancel the prior timer. That "shouldn't happen" path was the only consumer of the map, but the mutex fired on every successful refresh completion — another per-request Yaegi-dispatched lock. performCleanup is already idempotent (LoadAndDelete on the sync.Map), so a duplicate firing is at worst a no-op second call. Dropped the map entirely; time.AfterFunc callback now calls performCleanup directly. Net: -1 sync.Mutex, -1 map field, -2 Lock/Unlock pairs per refresh completion. Shutdown simplified — no need to enumerate-and-stop timers since the callbacks no longer need teardown. 2. Reorder applyLeaderGates: cooldown check BEFORE recordRefreshAttempt. Previously incremented the attempt counter and then checked cooldown. Under burst load (many concurrent leaders with different token hashes but the same session) every goroutine could increment past MaxRefreshAttempts before any one of them observed the threshold, so the gate fired too late — same thundering-herd shape that drove v1.0.14 into the ground. Reordering makes the gate authoritative: only attempts that pass the gate are recorded. Semantic change: with MaxRefreshAttempts=N, exactly N attempts now run to completion before the (N+1)th is denied. Previously the Nth was denied as it tried to record (off-by-one stricter). Test assertion updated to N (was N-1). 3. Fix getOrCreateOperation MaxConcurrentRefreshes overshoot. The previous CAS-loop allowed a transient overshoot of up to N-1 leaders when several goroutines all observed `current < max` in the same scheduling slice before any one of them succeeded their CAS — visible to readers as currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window. Replaced with the ticket-and-return pattern: increment optimistically, decrement if we overshot. Strictly bounded: only the goroutine that produces max+1 sees max+1 as committed; the rest decrement back immediately. No CAS retry loop needed. What was NOT done in this commit, and why: * metadataMu.RLock cached via atomic snapshot — code-reviewer flagged this at severity 7 (3 RLocks per request: middleware.go:213, token_manager.go:349, token_manager.go:408). The clean fix is an atomic.Pointer[*MetadataSnapshot], but generic atomic.Pointer[T] is NOT exposed by yaegi v0.16.1's stdlib (only legacy unsafe.Pointer primitives). atomic.Value would work but requires a snapshot-struct refactor across ~15 call sites (helpers/logout/token_introspection/ token_manager/main/middleware). Deferred to a focused future PR. * isInCooldown multi-field reset race — the cooldown-reset CAS wins on cooldownEndNano, then separately stores attempts/consecutiveFailures/ windowStartNano. A concurrent isInCooldown can briefly see the pre-reset attempts value and trigger a fresh cooldown. Semantic glitch (double-cooldown), not a correctness disaster. Fix is a single atomic pointer swap of an immutable snapshot — same atomic.Pointer constraint as above. Deferred. All tests pass with -race; golangci-lint clean.
763 lines
23 KiB
Go
763 lines
23 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestConcurrentRefreshDeduplication verifies that concurrent refresh attempts
|
|
// for the same token are deduplicated and only one refresh operation occurs
|
|
func TestConcurrentRefreshDeduplication(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
// Keep default delay for this test - it's testing deduplication behavior
|
|
// Disable rate limiting for this test since we're testing deduplication
|
|
config.MaxRefreshAttempts = 1000 // High enough to not interfere
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Counter to track actual refresh executions
|
|
var refreshExecutions int32
|
|
|
|
// Mock refresh function
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
atomic.AddInt32(&refreshExecutions, 1)
|
|
// Simulate some processing time
|
|
time.Sleep(100 * time.Millisecond)
|
|
return &TokenResponse{
|
|
AccessToken: "new_access_token",
|
|
RefreshToken: "new_refresh_token",
|
|
IDToken: "new_id_token",
|
|
ExpiresIn: 3600,
|
|
}, nil
|
|
}
|
|
|
|
// Number of concurrent requests
|
|
numRequests := 100
|
|
var wg sync.WaitGroup
|
|
wg.Add(numRequests)
|
|
|
|
// Channel to collect results
|
|
results := make(chan *TokenResponse, numRequests)
|
|
errors := make(chan error, numRequests)
|
|
|
|
// Launch concurrent refresh attempts with unique identifiers
|
|
refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())
|
|
sessionID := fmt.Sprintf("test_session_%d", time.Now().UnixNano())
|
|
|
|
for i := 0; i < numRequests; i++ {
|
|
go func(reqID int) {
|
|
defer wg.Done()
|
|
|
|
ctx := context.Background()
|
|
resp, err := coordinator.CoordinateRefresh(
|
|
ctx,
|
|
sessionID,
|
|
refreshToken,
|
|
refreshFunc,
|
|
)
|
|
|
|
if err != nil {
|
|
errors <- err
|
|
} else {
|
|
results <- resp
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Wait for all goroutines to complete
|
|
wg.Wait()
|
|
close(results)
|
|
close(errors)
|
|
|
|
// Verify results
|
|
actualExecutions := atomic.LoadInt32(&refreshExecutions)
|
|
// Allow for slight timing variations - up to 2 executions is acceptable
|
|
// This can happen when a second goroutine starts just as the first completes
|
|
if actualExecutions > 2 {
|
|
t.Errorf("Expected 1-2 refresh executions, got %d", actualExecutions)
|
|
}
|
|
|
|
// Verify all requests got the same result
|
|
var firstResponse *TokenResponse
|
|
responseCount := 0
|
|
|
|
for resp := range results {
|
|
responseCount++
|
|
if firstResponse == nil {
|
|
firstResponse = resp
|
|
} else {
|
|
// All responses should be identical (same pointer)
|
|
if resp.AccessToken != firstResponse.AccessToken {
|
|
t.Error("Different responses returned for concurrent requests")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check for errors
|
|
errorCount := 0
|
|
for range errors {
|
|
errorCount++
|
|
}
|
|
|
|
if errorCount > 0 {
|
|
t.Errorf("Unexpected errors in concurrent requests: %d", errorCount)
|
|
}
|
|
|
|
if responseCount != numRequests {
|
|
t.Errorf("Expected %d successful responses, got %d", numRequests, responseCount)
|
|
}
|
|
|
|
// Verify metrics
|
|
metrics := coordinator.GetMetrics()
|
|
if deduped, ok := metrics["deduplicated_requests"].(int64); ok {
|
|
// Allow for slight timing variations - at least 98 out of 100 should be deduplicated
|
|
if deduped < int64(numRequests-2) {
|
|
t.Errorf("Expected at least %d deduplicated requests, got %d", numRequests-2, deduped)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestRefreshRateLimiting verifies that refresh attempts are rate-limited per session
|
|
func TestRefreshRateLimiting(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.MaxRefreshAttempts = 3
|
|
config.RefreshAttemptWindow = 1 * time.Second
|
|
config.RefreshCooldownPeriod = 2 * time.Second
|
|
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Set circuit breaker to not interfere with rate limiting test
|
|
// We want to test rate limiting, not circuit breaker
|
|
coordinator.circuitBreaker.config.MaxFailures = 10
|
|
|
|
sessionID := "rate_limited_session"
|
|
refreshToken := "test_refresh_token"
|
|
|
|
// Mock refresh function that always fails
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
return nil, fmt.Errorf("refresh failed")
|
|
}
|
|
|
|
// Attempt refreshes beyond the limit
|
|
var attempts int
|
|
var cooldownTriggered bool
|
|
|
|
for i := 0; i < 5; i++ {
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
|
|
if err != nil {
|
|
if err.Error() == "refresh attempts exceeded for session, in cooldown period" {
|
|
cooldownTriggered = true
|
|
break
|
|
}
|
|
}
|
|
attempts++
|
|
// Add delay to ensure operations complete and aren't deduplicated
|
|
time.Sleep(150 * time.Millisecond)
|
|
}
|
|
|
|
// Verify that cooldown was triggered after max attempts.
|
|
// With applyLeaderGates checking cooldown BEFORE recording the attempt
|
|
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
|
|
// run to completion and the (N+1)th is denied. Previously the Nth was
|
|
// denied as it tried to record, which under burst load let multiple
|
|
// concurrent leaders increment past the limit before any one of them
|
|
// observed the gate.
|
|
expectedSuccessfulAttempts := config.MaxRefreshAttempts
|
|
if attempts != expectedSuccessfulAttempts {
|
|
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
|
|
}
|
|
|
|
if !cooldownTriggered {
|
|
t.Error("Cooldown was not triggered after max attempts")
|
|
}
|
|
|
|
// Verify that requests are blocked during cooldown
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
|
t.Error("Request should be blocked during cooldown period")
|
|
}
|
|
|
|
// Wait for cooldown to expire
|
|
time.Sleep(config.RefreshCooldownPeriod + 100*time.Millisecond)
|
|
|
|
// Verify that requests are allowed after cooldown
|
|
_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
if err != nil && err.Error() == "refresh attempts exceeded for session, in cooldown period" {
|
|
t.Error("Request should be allowed after cooldown period")
|
|
}
|
|
}
|
|
|
|
// TestCircuitBreakerProtection verifies that the circuit breaker prevents
|
|
// cascading failures during repeated refresh failures
|
|
func TestCircuitBreakerProtection(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Set circuit breaker to trip after 3 failures
|
|
coordinator.circuitBreaker.config.MaxFailures = 3
|
|
coordinator.circuitBreaker.config.OpenDuration = 1 * time.Second
|
|
|
|
// Mock refresh function that always fails
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
return nil, fmt.Errorf("service unavailable")
|
|
}
|
|
|
|
// Cause circuit breaker to trip
|
|
var tripCount int
|
|
for i := 0; i < 5; i++ {
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(
|
|
ctx,
|
|
fmt.Sprintf("session_%d", i), // Different sessions
|
|
"refresh_token",
|
|
refreshFunc,
|
|
)
|
|
|
|
if err != nil && err.Error() == "refresh circuit breaker is open due to repeated failures" {
|
|
tripCount++
|
|
}
|
|
}
|
|
|
|
// Verify circuit breaker tripped
|
|
if tripCount == 0 {
|
|
t.Error("Circuit breaker did not trip after repeated failures")
|
|
}
|
|
|
|
// Verify circuit breaker state
|
|
if coordinator.circuitBreaker.GetState() != "open" {
|
|
t.Errorf("Expected circuit breaker state 'open', got '%s'", coordinator.circuitBreaker.GetState())
|
|
}
|
|
|
|
// Wait for circuit to transition to half-open
|
|
time.Sleep(coordinator.circuitBreaker.config.OpenDuration + 100*time.Millisecond)
|
|
|
|
// Mock successful refresh
|
|
successfulRefreshFunc := func() (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
AccessToken: "new_token",
|
|
}, nil
|
|
}
|
|
|
|
// Verify circuit allows request in half-open state
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(ctx, "session_recovery", "refresh_token", successfulRefreshFunc)
|
|
if err != nil {
|
|
t.Errorf("Circuit breaker should allow request in half-open state: %v", err)
|
|
}
|
|
|
|
// Verify circuit closed after success
|
|
if coordinator.circuitBreaker.GetState() != "closed" {
|
|
t.Errorf("Expected circuit breaker state 'closed' after successful request, got '%s'",
|
|
coordinator.circuitBreaker.GetState())
|
|
}
|
|
}
|
|
|
|
// TestMemoryLeakPrevention verifies that the coordinator doesn't leak memory
|
|
// during sustained concurrent refresh operations
|
|
func TestMemoryLeakPrevention(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping memory leak test in short mode")
|
|
}
|
|
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.CleanupInterval = 100 * time.Millisecond
|
|
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Force garbage collection and record initial memory
|
|
runtime.GC()
|
|
runtime.GC()
|
|
var initialMem runtime.MemStats
|
|
runtime.ReadMemStats(&initialMem)
|
|
|
|
// Run sustained concurrent operations
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
var wg sync.WaitGroup
|
|
numWorkers := 10
|
|
wg.Add(numWorkers)
|
|
|
|
// Each worker continuously attempts refreshes
|
|
for i := 0; i < numWorkers; i++ {
|
|
go func(workerID int) {
|
|
defer wg.Done()
|
|
|
|
refreshCount := 0
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
// Simulate varying response times
|
|
time.Sleep(time.Duration(workerID*10) * time.Millisecond)
|
|
return &TokenResponse{
|
|
AccessToken: fmt.Sprintf("token_%d_%d", workerID, refreshCount),
|
|
RefreshToken: fmt.Sprintf("refresh_%d_%d", workerID, refreshCount),
|
|
}, nil
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
sessionID := fmt.Sprintf("session_%d", workerID)
|
|
refreshToken := fmt.Sprintf("refresh_%d_%d", workerID, refreshCount)
|
|
|
|
_, _ = coordinator.CoordinateRefresh(
|
|
context.Background(),
|
|
sessionID,
|
|
refreshToken,
|
|
refreshFunc,
|
|
)
|
|
|
|
refreshCount++
|
|
// Small delay to prevent CPU saturation
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
// Wait for workers to complete
|
|
wg.Wait()
|
|
|
|
// Allow cleanup to run
|
|
time.Sleep(2 * config.CleanupInterval)
|
|
|
|
// Force garbage collection and check memory
|
|
runtime.GC()
|
|
runtime.GC()
|
|
var finalMem runtime.MemStats
|
|
runtime.ReadMemStats(&finalMem)
|
|
|
|
// Calculate memory growth safely to prevent underflow
|
|
var memGrowthMB float64
|
|
if finalMem.HeapAlloc >= initialMem.HeapAlloc {
|
|
memGrowthMB = float64(finalMem.HeapAlloc-initialMem.HeapAlloc) / (1024 * 1024)
|
|
} else {
|
|
// Memory decreased (GC occurred), treat as 0 growth
|
|
memGrowthMB = 0
|
|
}
|
|
|
|
// Log memory statistics for debugging
|
|
t.Logf("Initial memory: %.2f MB", float64(initialMem.HeapAlloc)/(1024*1024))
|
|
t.Logf("Final memory: %.2f MB", float64(finalMem.HeapAlloc)/(1024*1024))
|
|
t.Logf("Memory growth: %.2f MB", memGrowthMB)
|
|
|
|
// Check for excessive memory growth (threshold: 50MB)
|
|
if memGrowthMB > 50 {
|
|
t.Errorf("Excessive memory growth detected: %.2f MB", memGrowthMB)
|
|
}
|
|
|
|
// Verify no lingering operations
|
|
metrics := coordinator.GetMetrics()
|
|
if inflight, ok := metrics["current_inflight"].(int32); ok {
|
|
if inflight != 0 {
|
|
t.Errorf("Expected 0 in-flight operations after completion, got %d", inflight)
|
|
}
|
|
}
|
|
|
|
// Verify cleanup is working. sync.Map has no Len(); count via Range.
|
|
sessionCount := 0
|
|
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
|
sessionCount++
|
|
return true
|
|
})
|
|
|
|
// Should have cleaned up old sessions (only recent ones remain)
|
|
if sessionCount > numWorkers*2 {
|
|
t.Errorf("Session cleanup not working properly, %d sessions remain", sessionCount)
|
|
}
|
|
}
|
|
|
|
// TestRefreshTimeoutHandling verifies that refresh operations timeout properly
|
|
func TestRefreshTimeoutHandling(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.RefreshTimeout = 100 * time.Millisecond
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Mock refresh function that hangs
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
time.Sleep(1 * time.Second) // Much longer than timeout
|
|
return &TokenResponse{AccessToken: "token"}, nil
|
|
}
|
|
|
|
ctx := context.Background()
|
|
start := time.Now()
|
|
|
|
_, err := coordinator.CoordinateRefresh(ctx, "session", "refresh_token", refreshFunc)
|
|
|
|
elapsed := time.Since(start)
|
|
|
|
// Verify timeout occurred
|
|
if err == nil {
|
|
t.Error("Expected timeout error, got nil")
|
|
}
|
|
|
|
// Verify it timed out within reasonable bounds
|
|
if elapsed > 200*time.Millisecond {
|
|
t.Errorf("Timeout took too long: %v", elapsed)
|
|
}
|
|
|
|
if err != nil && err.Error() != fmt.Sprintf("refresh operation timed out after %v", config.RefreshTimeout) {
|
|
t.Errorf("Unexpected error message: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestConcurrentDifferentTokens verifies that refreshes for different tokens
|
|
// proceed independently without blocking each other
|
|
func TestConcurrentDifferentTokens(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
numTokens := 10
|
|
var wg sync.WaitGroup
|
|
wg.Add(numTokens)
|
|
|
|
// Track execution order
|
|
executionOrder := make([]int, 0, numTokens)
|
|
var executionMutex sync.Mutex
|
|
|
|
for i := 0; i < numTokens; i++ {
|
|
go func(tokenID int) {
|
|
defer wg.Done()
|
|
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
executionMutex.Lock()
|
|
executionOrder = append(executionOrder, tokenID)
|
|
executionMutex.Unlock()
|
|
|
|
// Varying processing times
|
|
time.Sleep(time.Duration(tokenID*10) * time.Millisecond)
|
|
|
|
return &TokenResponse{
|
|
AccessToken: fmt.Sprintf("token_%d", tokenID),
|
|
RefreshToken: fmt.Sprintf("refresh_%d", tokenID),
|
|
}, nil
|
|
}
|
|
|
|
ctx := context.Background()
|
|
resp, err := coordinator.CoordinateRefresh(
|
|
ctx,
|
|
fmt.Sprintf("session_%d", tokenID),
|
|
fmt.Sprintf("refresh_token_%d", tokenID),
|
|
refreshFunc,
|
|
)
|
|
|
|
if err != nil {
|
|
t.Errorf("Token %d refresh failed: %v", tokenID, err)
|
|
}
|
|
|
|
if resp == nil || resp.AccessToken != fmt.Sprintf("token_%d", tokenID) {
|
|
t.Errorf("Token %d got wrong response", tokenID)
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Verify all tokens were processed
|
|
if len(executionOrder) != numTokens {
|
|
t.Errorf("Expected %d executions, got %d", numTokens, len(executionOrder))
|
|
}
|
|
|
|
// Verify no deduplication occurred (all different tokens)
|
|
metrics := coordinator.GetMetrics()
|
|
if deduped, ok := metrics["deduplicated_requests"].(int64); ok {
|
|
if deduped != 0 {
|
|
t.Errorf("No deduplication expected for different tokens, got %d", deduped)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMaxConcurrentRefreshes verifies that the coordinator respects
|
|
// the maximum concurrent refresh limit
|
|
func TestMaxConcurrentRefreshes(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.MaxConcurrentRefreshes = 2
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Track concurrent executions
|
|
var currentConcurrent int32
|
|
var maxConcurrent int32
|
|
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
current := atomic.AddInt32(¤tConcurrent, 1)
|
|
|
|
// Update max if needed
|
|
for {
|
|
max := atomic.LoadInt32(&maxConcurrent)
|
|
if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) {
|
|
break
|
|
}
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
atomic.AddInt32(¤tConcurrent, -1)
|
|
|
|
return &TokenResponse{AccessToken: "token"}, nil
|
|
}
|
|
|
|
numRequests := 10
|
|
var wg sync.WaitGroup
|
|
wg.Add(numRequests)
|
|
|
|
errors := make([]error, 0, numRequests)
|
|
var errorMutex sync.Mutex
|
|
|
|
for i := 0; i < numRequests; i++ {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(
|
|
ctx,
|
|
fmt.Sprintf("session_%d", id),
|
|
fmt.Sprintf("token_%d", id),
|
|
refreshFunc,
|
|
)
|
|
|
|
if err != nil {
|
|
errorMutex.Lock()
|
|
errors = append(errors, err)
|
|
errorMutex.Unlock()
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Some requests should have been rejected due to concurrency limit
|
|
if len(errors) == 0 {
|
|
t.Error("Expected some requests to be rejected due to concurrency limit")
|
|
}
|
|
|
|
// Verify max concurrent never exceeded limit
|
|
if maxConcurrent > int32(config.MaxConcurrentRefreshes) {
|
|
t.Errorf("Max concurrent refreshes (%d) exceeded limit (%d)",
|
|
maxConcurrent, config.MaxConcurrentRefreshes)
|
|
}
|
|
}
|
|
|
|
// TestSessionWindowReset verifies that refresh attempt windows reset properly
|
|
func TestSessionWindowReset(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.MaxRefreshAttempts = 2
|
|
config.RefreshAttemptWindow = 500 * time.Millisecond
|
|
config.RefreshCooldownPeriod = 2 * time.Second // Explicitly set cooldown > window
|
|
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
|
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Set circuit breaker to not interfere with rate limiting test
|
|
coordinator.circuitBreaker.config.MaxFailures = 10
|
|
|
|
// Use unique identifiers to prevent test interference
|
|
sessionID := fmt.Sprintf("window_test_session_%d", time.Now().UnixNano())
|
|
refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())
|
|
|
|
// Mock refresh function that always fails
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
return nil, fmt.Errorf("refresh failed")
|
|
}
|
|
|
|
// Use up the attempts in the first window
|
|
for i := 0; i < config.MaxRefreshAttempts; i++ {
|
|
ctx := context.Background()
|
|
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
// Add small delay to ensure attempts are registered separately
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
// Next attempt should trigger cooldown
|
|
ctx := context.Background()
|
|
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
|
t.Errorf("Expected cooldown after max attempts, got: %v", err)
|
|
}
|
|
|
|
// Wait for window to expire (but not cooldown)
|
|
// Use generous buffer for CI environments
|
|
time.Sleep(config.RefreshAttemptWindow + 200*time.Millisecond)
|
|
|
|
// Should still be in cooldown (cooldown=2s > window=500ms)
|
|
_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
|
t.Errorf("Should still be in cooldown period after window expiry, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// BenchmarkConcurrentRefreshDeduplication measures performance of deduplication
|
|
func BenchmarkConcurrentRefreshDeduplication(b *testing.B) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
time.Sleep(10 * time.Millisecond)
|
|
return &TokenResponse{
|
|
AccessToken: "token",
|
|
}, nil
|
|
}
|
|
|
|
b.ResetTimer()
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
i := 0
|
|
for pb.Next() {
|
|
ctx := context.Background()
|
|
sessionID := fmt.Sprintf("session_%d", i%10) // Reuse 10 sessions
|
|
refreshToken := fmt.Sprintf("token_%d", i%10) // Reuse 10 tokens
|
|
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
|
i++
|
|
}
|
|
})
|
|
|
|
b.StopTimer()
|
|
|
|
// Report metrics
|
|
metrics := coordinator.GetMetrics()
|
|
b.Logf("Total requests: %v", metrics["total_requests"])
|
|
b.Logf("Deduplicated: %v", metrics["deduplicated_requests"])
|
|
}
|
|
|
|
// TestCleanupRoutine verifies that the cleanup routine removes stale entries
|
|
func TestCleanupRoutine(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.CleanupInterval = 100 * time.Millisecond
|
|
config.RefreshAttemptWindow = 200 * time.Millisecond
|
|
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Add some sessions
|
|
for i := 0; i < 5; i++ {
|
|
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
|
}
|
|
|
|
countSessions := func() int {
|
|
n := 0
|
|
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
|
n++
|
|
return true
|
|
})
|
|
return n
|
|
}
|
|
|
|
if initialCount := countSessions(); initialCount != 5 {
|
|
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
|
}
|
|
|
|
// Wait for cleanup to run (2x window + cleanup interval)
|
|
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
|
|
|
if finalCount := countSessions(); finalCount != 0 {
|
|
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
|
}
|
|
}
|
|
|
|
// TestNoGoroutineExplosionWithTimers verifies that timer-based cleanup doesn't cause goroutine explosion
|
|
// This was the original issue: spawning a goroutine per refresh to sleep and cleanup
|
|
func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
|
logger := GetSingletonNoOpLogger()
|
|
config := DefaultRefreshCoordinatorConfig()
|
|
config.DeduplicationCleanupDelay = 100 * time.Millisecond // Non-zero delay
|
|
config.MaxConcurrentRefreshes = 100 // Allow many concurrent
|
|
config.MaxRefreshAttempts = 10000 // Don't rate limit
|
|
|
|
coordinator := NewRefreshCoordinator(config, logger)
|
|
defer coordinator.Shutdown()
|
|
|
|
// Record initial goroutines (allow settling time)
|
|
time.Sleep(50 * time.Millisecond)
|
|
runtime.GC()
|
|
initialGoroutines := runtime.NumGoroutine()
|
|
t.Logf("Initial goroutines: %d", initialGoroutines)
|
|
|
|
// Submit many refresh operations rapidly
|
|
const numRefreshes = 500
|
|
var wg sync.WaitGroup
|
|
wg.Add(numRefreshes)
|
|
|
|
refreshFunc := func() (*TokenResponse, error) {
|
|
return &TokenResponse{AccessToken: "token"}, nil
|
|
}
|
|
|
|
for i := 0; i < numRefreshes; i++ {
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
ctx := context.Background()
|
|
_, _ = coordinator.CoordinateRefresh(
|
|
ctx,
|
|
fmt.Sprintf("session_%d", id),
|
|
fmt.Sprintf("token_%d", id),
|
|
refreshFunc,
|
|
)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
// Measure goroutines immediately after all operations complete
|
|
// With the old approach, we'd have ~500 sleeping goroutines
|
|
// With the new timer approach, we should have much fewer
|
|
currentGoroutines := runtime.NumGoroutine()
|
|
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
|
|
|
|
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
|
|
// fire performCleanup directly. This test now only checks the goroutine
|
|
// budget, which was always the real invariant.)
|
|
|
|
// With timer-based cleanup, goroutine increase should be minimal
|
|
// Timers don't create goroutines - they use the runtime timer heap
|
|
goroutineIncrease := currentGoroutines - initialGoroutines
|
|
|
|
// Allow for some goroutine overhead (test framework, etc)
|
|
// With the old approach, we'd see ~500 goroutines
|
|
// With the new approach, we should see <50 (much smaller)
|
|
maxAcceptableIncrease := 100 // Very generous limit
|
|
|
|
if goroutineIncrease > maxAcceptableIncrease {
|
|
t.Errorf("Goroutine explosion detected: started with %d, now have %d (increase of %d)",
|
|
initialGoroutines, currentGoroutines, goroutineIncrease)
|
|
}
|
|
|
|
// Wait for timers to fire and cleanup.
|
|
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
|
|
|
|
// Verify goroutines returned to near initial
|
|
runtime.GC()
|
|
time.Sleep(50 * time.Millisecond)
|
|
finalGoroutines := runtime.NumGoroutine()
|
|
t.Logf("Final goroutines: %d", finalGoroutines)
|
|
|
|
// Should be close to initial (within tolerance)
|
|
finalIncrease := finalGoroutines - initialGoroutines
|
|
if finalIncrease > 20 {
|
|
t.Errorf("Goroutine leak detected: started with %d, ended with %d (increase of %d)",
|
|
initialGoroutines, finalGoroutines, finalIncrease)
|
|
}
|
|
}
|