mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Fix recursion in token resilience logic
This commit is contained in:
@@ -539,3 +539,233 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
metrics := coordinator.GetMetrics()
|
||||
t.Logf("Final metrics: %+v", metrics)
|
||||
}
|
||||
|
||||
// TestIssue67_TokenResilienceRecursionBug directly tests the recursion bug identified by jetexe
|
||||
// in the comment: https://github.com/lukaszraczylo/traefikoidc/issues/67#issuecomment-2391821890
|
||||
//
|
||||
// The bug is in token_resilience.go:180-190 where ExecuteTokenRefresh calls
|
||||
// getNewTokenWithRefreshToken which calls ExecuteTokenRefresh again, causing infinite recursion.
|
||||
func TestIssue67_TokenResilienceRecursionBug(t *testing.T) {
|
||||
// Track call depth to detect recursion
|
||||
var callDepth int32
|
||||
var maxDepth int32 = 5 // If we reach this, we have recursion
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/token" {
|
||||
// Increment call depth
|
||||
depth := atomic.AddInt32(&callDepth, 1)
|
||||
defer atomic.AddInt32(&callDepth, -1)
|
||||
|
||||
// Check if we've exceeded max depth (indicates recursion)
|
||||
if depth > maxDepth {
|
||||
t.Errorf("Call depth exceeded %d - infinite recursion detected!", maxDepth)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Simulate successful token refresh
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{
|
||||
"access_token": "new_access_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"id_token": "new_id_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer"
|
||||
}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create TraefikOidc with tokenResilienceManager (this triggers the bug)
|
||||
logger := GetSingletonNoOpLogger()
|
||||
resilienceConfig := DefaultTokenResilienceConfig()
|
||||
resilienceManager := NewTokenResilienceManager(resilienceConfig, logger)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenResilienceManager: resilienceManager,
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Create context with timeout to prevent hanging
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Run in goroutine to detect stack overflow
|
||||
done := make(chan struct{})
|
||||
var testErr error
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
testErr = fmt.Errorf("panic recovered: %v (likely stack overflow from recursion)", r)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// This call should NOT recurse infinitely after the fix
|
||||
_, err := oidc.getNewTokenWithRefreshToken("test_refresh_token")
|
||||
if err != nil {
|
||||
testErr = err
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for completion or timeout
|
||||
select {
|
||||
case <-done:
|
||||
// Check for recursion via call depth
|
||||
if atomic.LoadInt32(&callDepth) > maxDepth {
|
||||
t.Fatal("Infinite recursion detected via call depth counter")
|
||||
}
|
||||
|
||||
// Check for panic/stack overflow
|
||||
if testErr != nil && strings.Contains(testErr.Error(), "stack overflow") {
|
||||
t.Fatalf("Stack overflow detected: %v", testErr)
|
||||
}
|
||||
|
||||
// After fix, this should succeed
|
||||
if testErr != nil {
|
||||
t.Logf("Token refresh completed with error: %v", testErr)
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Test timed out - likely infinite recursion in getNewTokenWithRefreshToken -> ExecuteTokenRefresh loop")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue67_TokenResilienceManager_NoRecursion verifies ExecuteTokenRefresh
|
||||
// calls exchangeTokens directly and doesn't recurse back to getNewTokenWithRefreshToken
|
||||
func TestIssue67_TokenResilienceManager_NoRecursion(t *testing.T) {
|
||||
var exchangeTokensCalls int32
|
||||
var getNewTokenCalls int32
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&exchangeTokensCalls, 1)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{
|
||||
"access_token": "test_token",
|
||||
"refresh_token": "test_refresh",
|
||||
"id_token": "test_id",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer"
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create TraefikOidc with instrumented methods
|
||||
logger := GetSingletonNoOpLogger()
|
||||
resilienceConfig := DefaultTokenResilienceConfig()
|
||||
resilienceManager := NewTokenResilienceManager(resilienceConfig, logger)
|
||||
|
||||
// Create custom TraefikOidc to track calls
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenResilienceManager: resilienceManager,
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Wrap getNewTokenWithRefreshToken to count calls
|
||||
originalGetNewToken := oidc.getNewTokenWithRefreshToken
|
||||
wrappedGetNewToken := func(refreshToken string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&getNewTokenCalls, 1)
|
||||
return originalGetNewToken(refreshToken)
|
||||
}
|
||||
_ = wrappedGetNewToken // Use the wrapper
|
||||
|
||||
// Execute token refresh through resilience manager
|
||||
ctx := context.Background()
|
||||
_, err := resilienceManager.ExecuteTokenRefresh(ctx, oidc, "test_refresh_token")
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Token refresh returned error (may be expected): %v", err)
|
||||
}
|
||||
|
||||
// Verify exchangeTokens was called
|
||||
exchangeCalls := atomic.LoadInt32(&exchangeTokensCalls)
|
||||
if exchangeCalls == 0 {
|
||||
t.Error("exchangeTokens was never called")
|
||||
}
|
||||
|
||||
t.Logf("exchangeTokens called %d times", exchangeCalls)
|
||||
|
||||
// After the fix, ExecuteTokenRefresh should call exchangeTokens directly
|
||||
// and NOT call getNewTokenWithRefreshToken (which would cause recursion)
|
||||
}
|
||||
|
||||
// TestIssue67_DirectRecursionDetection uses a simpler approach to detect the recursion
|
||||
func TestIssue67_DirectRecursionDetection(t *testing.T) {
|
||||
// This test will fail BEFORE the fix and pass AFTER the fix
|
||||
|
||||
var recursionDepth int32
|
||||
const maxAllowedDepth = 3
|
||||
|
||||
// Create a simple mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
depth := atomic.AddInt32(&recursionDepth, 1)
|
||||
defer atomic.AddInt32(&recursionDepth, -1)
|
||||
|
||||
if depth > maxAllowedDepth {
|
||||
// Recursion detected - fail fast
|
||||
t.Errorf("RECURSION BUG DETECTED: depth=%d exceeds max=%d", depth, maxAllowedDepth)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"access_token":"test","refresh_token":"test","id_token":"test","expires_in":3600,"token_type":"Bearer"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultTokenResilienceConfig()
|
||||
config.RetryEnabled = false // Disable retries to make the test clearer
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test",
|
||||
clientSecret: "test",
|
||||
tokenResilienceManager: NewTokenResilienceManager(config, logger),
|
||||
tokenHTTPClient: &http.Client{Timeout: 2 * time.Second},
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Set a timeout to prevent infinite hangs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := oidc.getNewTokenWithRefreshToken("test_token")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
finalDepth := atomic.LoadInt32(&recursionDepth)
|
||||
if finalDepth > maxAllowedDepth {
|
||||
t.Fatalf("Recursion bug confirmed: max depth reached %d", finalDepth)
|
||||
}
|
||||
if err != nil {
|
||||
t.Logf("Completed with error: %v", err)
|
||||
} else {
|
||||
t.Log("Token refresh completed successfully without recursion")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Fatal("RECURSION BUG: Test timed out, indicating infinite loop in getNewTokenWithRefreshToken -> ExecuteTokenRefresh")
|
||||
}
|
||||
}
|
||||
|
||||
+3
-1
@@ -182,7 +182,9 @@ func (trm *TokenResilienceManager) ExecuteTokenRefresh(ctx context.Context, t *T
|
||||
var err error
|
||||
|
||||
err = trm.ExecuteTokenOperation(ctx, "token_refresh", func() error {
|
||||
result, err = t.getNewTokenWithRefreshToken(refreshToken)
|
||||
// Call exchangeTokens directly to avoid recursion back to getNewTokenWithRefreshToken
|
||||
// which would call ExecuteTokenRefresh again, causing infinite loop (issue #67)
|
||||
result, err = t.exchangeTokens(ctx, "refresh_token", refreshToken, "", "")
|
||||
return err
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user