From 79d34ea4c99d39699ba0cf1096a18a84f88d8ca5 Mon Sep 17 00:00:00 2001
From: Lukasz Raczylo
Date: Tue, 7 Oct 2025 10:34:15 +0100
Subject: [PATCH] Fix recursion in token resilience logic (#72)

---
 issue67_regression_test.go | 219 +++++++++++++++++++++++++++++++++++++
 token_resilience.go        |   4 +-
 2 files changed, 222 insertions(+), 1 deletion(-)

diff --git a/issue67_regression_test.go b/issue67_regression_test.go
index 7e58d57..0afe5af 100644
--- a/issue67_regression_test.go
+++ b/issue67_regression_test.go
@@ -539,3 +539,222 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
 	metrics := coordinator.GetMetrics()
 	t.Logf("Final metrics: %+v", metrics)
 }
+
+// TestIssue67_TokenResilienceRecursionBug directly tests the recursion bug identified by jetexe
+// in the comment: https://github.com/lukaszraczylo/traefikoidc/issues/67#issuecomment-2391821890
+//
+// The bug was in token_resilience.go:180-190, where ExecuteTokenRefresh called
+// getNewTokenWithRefreshToken, which called ExecuteTokenRefresh again, causing infinite recursion.
+//
+// Unbounded recursion in Go ends in a fatal stack overflow that recover() cannot intercept, so
+// a regression of that kind aborts the whole test binary. The checks below additionally catch
+// refresh loops that keep hitting the token endpoint, panics, and hangs.
+func TestIssue67_TokenResilienceRecursionBug(t *testing.T) {
+	// Total number of requests the token endpoint may receive before we call it a loop.
+	const maxTokenRequests = 5
+	var tokenRequests int32
+
+	// Create mock server
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/token" {
+			return
+		}
+
+		// Count every request instead of tracking in-flight ones: an in-flight gauge
+		// drops back to zero once the handlers return and can no longer flag a loop.
+		if n := atomic.AddInt32(&tokenRequests, 1); n > maxTokenRequests {
+			t.Errorf("Token endpoint called %d times (max %d) - refresh loop detected!", n, maxTokenRequests)
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+
+		// Simulate successful token refresh
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		// A failed write surfaces as a client-side error, so the result is ignored here.
+		_, _ = w.Write([]byte(`{
+			"access_token": "new_access_token",
+			"refresh_token": "new_refresh_token",
+			"id_token": "new_id_token",
+			"expires_in": 3600,
+			"token_type": "Bearer"
+		}`))
+	}))
+	defer server.Close()
+
+	// Create TraefikOidc with tokenResilienceManager (this triggers the bug)
+	logger := GetSingletonNoOpLogger()
+	resilienceConfig := DefaultTokenResilienceConfig()
+	resilienceManager := NewTokenResilienceManager(resilienceConfig, logger)
+
+	oidc := &TraefikOidc{
+		tokenURL:               server.URL + "/token",
+		clientID:               "test_client",
+		clientSecret:           "test_secret",
+		tokenResilienceManager: resilienceManager,
+		tokenHTTPClient: &http.Client{
+			Timeout: 5 * time.Second,
+		},
+		logger: logger,
+	}
+
+	// Create context with timeout to prevent hanging
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// Run the refresh in a goroutine so a hang is reported as a timeout. The channel is
+	// buffered so the goroutine can still finish if the test has already given up on it.
+	type refreshOutcome struct {
+		err      error
+		panicked interface{}
+	}
+	done := make(chan refreshOutcome, 1)
+
+	go func() {
+		var outcome refreshOutcome
+		defer func() {
+			outcome.panicked = recover()
+			done <- outcome
+		}()
+
+		// This call should NOT recurse infinitely after the fix
+		_, outcome.err = oidc.getNewTokenWithRefreshToken("test_refresh_token")
+	}()
+
+	// Wait for completion or timeout
+	select {
+	case outcome := <-done:
+		// Any panic is a failure; report it for what it is.
+		if outcome.panicked != nil {
+			t.Fatalf("getNewTokenWithRefreshToken panicked: %v", outcome.panicked)
+		}
+
+		// Check for a refresh loop via the request counter
+		if n := atomic.LoadInt32(&tokenRequests); n > maxTokenRequests {
+			t.Fatalf("Refresh loop detected: token endpoint called %d times (max %d)", n, maxTokenRequests)
+		}
+
+		// After fix, this should succeed
+		if outcome.err != nil {
+			t.Logf("Token refresh completed with error: %v", outcome.err)
+		}
+
+	case <-ctx.Done():
+		t.Fatal("Test timed out - likely infinite recursion in getNewTokenWithRefreshToken -> ExecuteTokenRefresh loop")
+	}
+}
+
+// TestIssue67_TokenResilienceManager_NoRecursion verifies that ExecuteTokenRefresh reaches the
+// token endpoint and returns instead of recursing back through getNewTokenWithRefreshToken.
+func TestIssue67_TokenResilienceManager_NoRecursion(t *testing.T) {
+	var exchangeTokensCalls int32
+
+	// Create mock server; every request it receives stands for one exchangeTokens call
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt32(&exchangeTokensCalls, 1)
+
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		// A failed write surfaces as a client-side error, so the result is ignored here.
+		_, _ = w.Write([]byte(`{
+			"access_token": "test_token",
+			"refresh_token": "test_refresh",
+			"id_token": "test_id",
+			"expires_in": 3600,
+			"token_type": "Bearer"
+		}`))
+	}))
+	defer server.Close()
+
+	logger := GetSingletonNoOpLogger()
+	resilienceConfig := DefaultTokenResilienceConfig()
+	resilienceManager := NewTokenResilienceManager(resilienceConfig, logger)
+
+	oidc := &TraefikOidc{
+		tokenURL:               server.URL + "/token",
+		clientID:               "test_client",
+		clientSecret:           "test_secret",
+		tokenResilienceManager: resilienceManager,
+		tokenHTTPClient: &http.Client{
+			Timeout: 5 * time.Second,
+		},
+		logger: logger,
+	}
+
+	// getNewTokenWithRefreshToken is a method and cannot be swapped out for a counting
+	// wrapper, so the refresh is observed through the requests that reach the mock server.
+	ctx := context.Background()
+	_, err := resilienceManager.ExecuteTokenRefresh(ctx, oidc, "test_refresh_token")
+
+	if err != nil {
+		t.Logf("Token refresh returned error (may be expected): %v", err)
+	}
+
+	// Verify exchangeTokens was called
+	exchangeCalls := atomic.LoadInt32(&exchangeTokensCalls)
+	if exchangeCalls == 0 {
+		t.Error("exchangeTokens was never called")
+	}
+
+	t.Logf("exchangeTokens called %d times", exchangeCalls)
+}
+
+// TestIssue67_DirectRecursionDetection uses a simpler approach to detect the recursion
+func TestIssue67_DirectRecursionDetection(t *testing.T) {
+	// With retries disabled, a refresh that terminates is expected to send a single token
+	// request; the bound leaves some slack before a loop is reported.
+	var tokenRequests int32
+	const maxAllowedRequests = 3
+
+	// Create a simple mock server
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if n := atomic.AddInt32(&tokenRequests, 1); n > maxAllowedRequests {
+			// Loop detected - fail fast
+			t.Errorf("RECURSION BUG DETECTED: requests=%d exceeds max=%d", n, maxAllowedRequests)
+			w.WriteHeader(http.StatusInternalServerError)
+			return
+		}
+
+		w.Header().Set("Content-Type", "application/json")
+		// A failed write surfaces as a client-side error, so the result is ignored here.
+		_, _ = w.Write([]byte(`{"access_token":"test","refresh_token":"test","id_token":"test","expires_in":3600,"token_type":"Bearer"}`))
+	}))
+	defer server.Close()
+
+	logger := GetSingletonNoOpLogger()
+	config := DefaultTokenResilienceConfig()
+	config.RetryEnabled = false // Disable retries to make the test clearer
+
+	oidc := &TraefikOidc{
+		tokenURL:               server.URL + "/token",
+		clientID:               "test",
+		clientSecret:           "test",
+		tokenResilienceManager: NewTokenResilienceManager(config, logger),
+		tokenHTTPClient:        &http.Client{Timeout: 2 * time.Second},
+		logger:                 logger,
+	}
+
+	// Set a timeout to prevent infinite hangs
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		_, err := oidc.getNewTokenWithRefreshToken("test_token")
+		done <- err
+	}()
+
+	select {
+	case err := <-done:
+		if n := atomic.LoadInt32(&tokenRequests); n > maxAllowedRequests {
+			t.Fatalf("Recursion bug confirmed: token endpoint called %d times", n)
+		}
+		if err != nil {
+			t.Logf("Completed with error: %v", err)
+		} else {
+			t.Log("Token refresh completed successfully without recursion")
+		}
+	case <-ctx.Done():
+		t.Fatal("RECURSION BUG: Test timed out, indicating infinite loop in getNewTokenWithRefreshToken -> ExecuteTokenRefresh")
+	}
+}
diff --git a/token_resilience.go b/token_resilience.go
index 566c5bb..b93f042 100644
--- a/token_resilience.go
+++ b/token_resilience.go
@@ -182,7 +182,9 @@ func (trm *TokenResilienceManager) ExecuteTokenRefresh(ctx context.Context, t *T
 	var err error
 
 	err = trm.ExecuteTokenOperation(ctx, "token_refresh", func() error {
-		result, err = t.getNewTokenWithRefreshToken(refreshToken)
+		// Call exchangeTokens directly to avoid recursion back to getNewTokenWithRefreshToken
+		// which would call ExecuteTokenRefresh again, causing infinite loop (issue #67)
+		result, err = t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
 		return err
 	})
 