mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b07247f674 | |||
| 1e4142a7fb |
@@ -0,0 +1,920 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// generateLargeRealisticToken builds a JWT-shaped string of the form
// header.payload.signature, where each part is base64url-encoded without
// padding. The payload carries a set of standard-looking claims plus 100
// generated custom claims and two array claims, so the resulting token is
// large enough to exercise cookie chunking.
//
// The signature part is a fixed placeholder, not a cryptographic signature.
// The function panics if JSON encoding of the header or claims fails; the
// inputs are fixed literals, so a failure indicates a programming error.
func generateLargeRealisticToken() string {
	// encodeJSON marshals v and returns the unpadded base64url encoding of the
	// JSON bytes. Errors are surfaced instead of being silently discarded so
	// that a broken fixture can never yield a malformed token unnoticed.
	encodeJSON := func(part string, v interface{}) string {
		raw, err := json.Marshal(v)
		if err != nil {
			panic(fmt.Sprintf("generateLargeRealisticToken: marshaling %s: %v", part, err))
		}
		return base64.RawURLEncoding.EncodeToString(raw)
	}

	// Realistic JWT header.
	header := map[string]interface{}{
		"alg": "RS256",
		"typ": "JWT",
		"kid": "test-key-id",
	}

	// Large but realistic payload with many claims.
	claims := map[string]interface{}{
		"iss":   "https://auth.example.com/",
		"sub":   "auth0|507f1f77bcf86cd799439011",
		"aud":   []string{"https://api.example.com", "https://app.example.com"},
		"iat":   1516239022,
		"exp":   1516325422,
		"azp":   "my_client_id",
		"scope": "openid profile email read:users write:users admin",
		"gty":   "client-credentials",
	}

	// Many custom claims to inflate the token size.
	for i := 0; i < 100; i++ {
		claimName := fmt.Sprintf("custom_claim_%d", i)
		claims[claimName] = fmt.Sprintf("This is a test value for claim %d with some additional data to make it larger", i)
	}

	// Array claims with multiple values.
	claims["permissions"] = []string{
		"read:users", "write:users", "delete:users", "create:users",
		"read:posts", "write:posts", "delete:posts", "create:posts",
		"admin:all", "super:admin", "system:manage", "audit:view",
	}
	claims["groups"] = []string{
		"administrators", "developers", "qa_team", "devops",
		"product_managers", "support_team", "security_team",
	}

	// Mock signature (a real token would carry a cryptographic signature here).
	signature := base64.RawURLEncoding.EncodeToString(
		[]byte("mock_signature_with_some_additional_bytes_for_testing_purposes"))

	return fmt.Sprintf("%s.%s.%s", encodeJSON("header", header), encodeJSON("claims", claims), signature)
}
|
||||
|
||||
// TestAuth0RedirectLoopFix tests the fixes applied to prevent Auth0 redirect loops
|
||||
// specifically focusing on:
|
||||
// 1. Consistent cookie configuration (Path="/", SameSite=Lax)
|
||||
// 2. CSRF token accessibility during OAuth callbacks
|
||||
// 3. Session cookie persistence across OAuth flow
|
||||
// 4. Redirect loop prevention
|
||||
func TestAuth0RedirectLoopFix(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
t.Run("CookieConfigurationConsistency", func(t *testing.T) {
|
||||
testCookieConfigurationConsistency(t, sm)
|
||||
})
|
||||
|
||||
t.Run("CSRFTokenAccessibility", func(t *testing.T) {
|
||||
testCSRFTokenAccessibility(t, sm)
|
||||
})
|
||||
|
||||
t.Run("SessionPersistenceAcrossOAuth", func(t *testing.T) {
|
||||
testSessionPersistenceAcrossOAuth(t, sm)
|
||||
})
|
||||
|
||||
t.Run("RedirectLoopPrevention", func(t *testing.T) {
|
||||
testRedirectLoopPrevention(t, sm)
|
||||
})
|
||||
|
||||
t.Run("CallbackCSRFValidation", func(t *testing.T) {
|
||||
testCallbackCSRFValidation(t, sm)
|
||||
})
|
||||
|
||||
t.Run("EdgeCases", func(t *testing.T) {
|
||||
testEdgeCases(t, sm)
|
||||
})
|
||||
}
|
||||
|
||||
// testCookieConfigurationConsistency verifies that cookies are configured
|
||||
// consistently with Path="/" and SameSite=Lax regardless of request headers
|
||||
func testCookieConfigurationConsistency(t *testing.T, sm *SessionManager) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expectPath string
|
||||
expectSame http.SameSite
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Standard HTTP request should get consistent cookie config",
|
||||
},
|
||||
{
|
||||
name: "XMLHttpRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "XMLHttpRequest should still use SameSite=Lax (fix for redirect loop)",
|
||||
},
|
||||
{
|
||||
name: "HTTPSRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "HTTPS requests should have consistent cookie config",
|
||||
},
|
||||
{
|
||||
name: "CustomDomainRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "auth.example.com",
|
||||
"X-Forwarded-Host": "auth.example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectPath: "/",
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Custom domain requests should maintain consistent config",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/callback", nil)
|
||||
|
||||
// Set headers
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Get session and save it to trigger cookie setting
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set some session data to ensure it gets saved
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetAuthenticated(false)
|
||||
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookie configuration
|
||||
cookies := rw.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatal("No cookies set in response")
|
||||
}
|
||||
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != tt.expectPath {
|
||||
t.Errorf("Expected Path=%s, got Path=%s for cookie %s",
|
||||
tt.expectPath, cookie.Path, cookie.Name)
|
||||
}
|
||||
if cookie.SameSite != tt.expectSame {
|
||||
t.Errorf("Expected SameSite=%v, got SameSite=%v for cookie %s",
|
||||
tt.expectSame, cookie.SameSite, cookie.Name)
|
||||
}
|
||||
t.Logf("Cookie %s: Path=%s, SameSite=%v, Secure=%v, HttpOnly=%v",
|
||||
cookie.Name, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly)
|
||||
}
|
||||
}
|
||||
|
||||
session.Clear(req, nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testCSRFTokenAccessibility verifies that CSRF tokens remain accessible
|
||||
// during OAuth callbacks regardless of request type
|
||||
func testCSRFTokenAccessibility(t *testing.T, sm *SessionManager) {
|
||||
csrfToken := uuid.New().String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
description: "Standard OAuth callback should access CSRF token",
|
||||
},
|
||||
{
|
||||
name: "AjaxCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
description: "AJAX OAuth callback should access CSRF token",
|
||||
},
|
||||
{
|
||||
name: "HTTPSCallback",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
description: "HTTPS OAuth callback should access CSRF token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Phase 1: Store CSRF token in session (auth initiation)
|
||||
initReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
for key, value := range tt.headers {
|
||||
initReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
initRw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(initReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetIncomingPath("/protected")
|
||||
|
||||
err = session.Save(initReq, initRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies from response to simulate browser behavior
|
||||
storedCookies := initRw.Result().Cookies()
|
||||
|
||||
// Phase 2: OAuth callback with same cookies
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=auth_code", nil)
|
||||
|
||||
for key, value := range tt.headers {
|
||||
callbackReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Add cookies to callback request
|
||||
for _, cookie := range storedCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get session in callback
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Verify CSRF token is accessible
|
||||
retrievedCSRF := callbackSession.GetCSRF()
|
||||
if retrievedCSRF == "" {
|
||||
t.Error("CSRF token not accessible in callback session")
|
||||
}
|
||||
if retrievedCSRF != csrfToken {
|
||||
t.Errorf("CSRF token mismatch: expected %s, got %s", csrfToken, retrievedCSRF)
|
||||
}
|
||||
|
||||
// Verify other session data is accessible
|
||||
if callbackSession.GetNonce() != "test-nonce" {
|
||||
t.Error("Nonce not accessible in callback session")
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Error("Incoming path not accessible in callback session")
|
||||
}
|
||||
|
||||
t.Logf("CSRF token successfully retrieved in %s: %s", tt.name, retrievedCSRF)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testSessionPersistenceAcrossOAuth verifies that session data persists
|
||||
// correctly across the OAuth flow without being lost due to cookie issues
|
||||
func testSessionPersistenceAcrossOAuth(t *testing.T, sm *SessionManager) {
|
||||
// Simulate complete OAuth flow
|
||||
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// Phase 1: Initial authentication request
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get initial session: %v", err)
|
||||
}
|
||||
|
||||
csrfToken := uuid.New().String()
|
||||
nonce := "test-nonce-" + uuid.New().String()
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath("/protected")
|
||||
session.SetCodeVerifier("test-code-verifier")
|
||||
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save initial session: %v", err)
|
||||
}
|
||||
|
||||
initialCookies := rw.Result().Cookies()
|
||||
if len(initialCookies) == 0 {
|
||||
t.Fatal("No cookies set in initial response")
|
||||
}
|
||||
|
||||
// Phase 2: OAuth provider redirect (user authenticates)
|
||||
redirectReq := httptest.NewRequest("GET", "https://auth0.example.com/authorize", nil)
|
||||
// Add cookies as browser would
|
||||
for _, cookie := range initialCookies {
|
||||
redirectReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Phase 3: OAuth callback
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=auth_code_12345", nil)
|
||||
callbackReq.Header.Set("Host", "example.com")
|
||||
callbackReq.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
// Add all cookies from initial response
|
||||
for _, cookie := range initialCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackRw := httptest.NewRecorder()
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Verify all session data persisted
|
||||
if callbackSession.GetCSRF() != csrfToken {
|
||||
t.Errorf("CSRF token not persisted: expected %s, got %s",
|
||||
csrfToken, callbackSession.GetCSRF())
|
||||
}
|
||||
if callbackSession.GetNonce() != nonce {
|
||||
t.Errorf("Nonce not persisted: expected %s, got %s",
|
||||
nonce, callbackSession.GetNonce())
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Errorf("Incoming path not persisted: expected /protected, got %s",
|
||||
callbackSession.GetIncomingPath())
|
||||
}
|
||||
if callbackSession.GetCodeVerifier() != "test-code-verifier" {
|
||||
t.Errorf("Code verifier not persisted: expected test-code-verifier, got %s",
|
||||
callbackSession.GetCodeVerifier())
|
||||
}
|
||||
|
||||
// Simulate successful authentication
|
||||
callbackSession.SetAuthenticated(true)
|
||||
callbackSession.SetEmail("user@example.com")
|
||||
callbackSession.SetAccessToken("access_token_12345")
|
||||
callbackSession.SetRefreshToken("refresh_token_12345")
|
||||
callbackSession.SetIDToken("id_token_12345")
|
||||
|
||||
// Clear OAuth-specific data
|
||||
callbackSession.SetCSRF("")
|
||||
callbackSession.SetNonce("")
|
||||
callbackSession.SetCodeVerifier("")
|
||||
callbackSession.ResetRedirectCount()
|
||||
|
||||
err = callbackSession.Save(callbackReq, callbackRw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save callback session: %v", err)
|
||||
}
|
||||
|
||||
t.Log("OAuth flow simulation completed successfully - session data persisted")
|
||||
}
|
||||
|
||||
// testRedirectLoopPrevention verifies that the redirect loop prevention
|
||||
// mechanisms work correctly
|
||||
func testRedirectLoopPrevention(t *testing.T, sm *SessionManager) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
// Test redirect count tracking
|
||||
initialCount := session.GetRedirectCount()
|
||||
if initialCount != 0 {
|
||||
t.Errorf("Initial redirect count should be 0, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Simulate multiple redirect attempts
|
||||
for i := 1; i <= 6; i++ {
|
||||
session.IncrementRedirectCount()
|
||||
count := session.GetRedirectCount()
|
||||
if count != i {
|
||||
t.Errorf("Expected redirect count %d, got %d", i, count)
|
||||
}
|
||||
|
||||
// Test that redirect loop detection kicks in at 5 redirects
|
||||
if i >= 5 {
|
||||
t.Logf("Redirect count at %d - should trigger loop detection", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Test reset functionality
|
||||
session.ResetRedirectCount()
|
||||
if session.GetRedirectCount() != 0 {
|
||||
t.Errorf("Redirect count should be 0 after reset, got %d", session.GetRedirectCount())
|
||||
}
|
||||
|
||||
t.Log("Redirect loop prevention tests passed")
|
||||
}
|
||||
|
||||
// testCallbackCSRFValidation tests CSRF token validation in OAuth callbacks
|
||||
func testCallbackCSRFValidation(t *testing.T, sm *SessionManager) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storedCSRF string
|
||||
callbackState string
|
||||
shouldSucceed bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidCSRF",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "valid-csrf-token-123",
|
||||
shouldSucceed: true,
|
||||
description: "Valid CSRF token should pass validation",
|
||||
},
|
||||
{
|
||||
name: "InvalidCSRF",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "different-csrf-token-456",
|
||||
shouldSucceed: false,
|
||||
description: "Invalid CSRF token should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyStoredCSRF",
|
||||
storedCSRF: "",
|
||||
callbackState: "some-csrf-token",
|
||||
shouldSucceed: false,
|
||||
description: "Empty stored CSRF should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyCallbackState",
|
||||
storedCSRF: "valid-csrf-token-123",
|
||||
callbackState: "",
|
||||
shouldSucceed: false,
|
||||
description: "Empty callback state should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup phase - store CSRF token
|
||||
setupReq := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
setupReq.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(setupReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get setup session: %v", err)
|
||||
}
|
||||
|
||||
if tt.storedCSRF != "" {
|
||||
session.SetCSRF(tt.storedCSRF)
|
||||
}
|
||||
|
||||
setupRw := httptest.NewRecorder()
|
||||
err = session.Save(setupReq, setupRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save setup session: %v", err)
|
||||
}
|
||||
|
||||
setupCookies := setupRw.Result().Cookies()
|
||||
|
||||
// Callback phase - validate CSRF
|
||||
callbackURL := "http://example.com/callback"
|
||||
if tt.callbackState != "" {
|
||||
callbackURL += "?state=" + tt.callbackState + "&code=test_code"
|
||||
} else {
|
||||
callbackURL += "?code=test_code"
|
||||
}
|
||||
|
||||
callbackReq := httptest.NewRequest("GET", callbackURL, nil)
|
||||
callbackReq.Header.Set("Host", "example.com")
|
||||
|
||||
// Add cookies
|
||||
for _, cookie := range setupCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// Perform CSRF validation
|
||||
storedCSRF := callbackSession.GetCSRF()
|
||||
stateParam := callbackReq.URL.Query().Get("state")
|
||||
|
||||
csrfValid := (storedCSRF != "" && stateParam != "" && storedCSRF == stateParam)
|
||||
|
||||
if tt.shouldSucceed && !csrfValid {
|
||||
t.Errorf("CSRF validation should have succeeded but failed. Stored: '%s', State: '%s'",
|
||||
storedCSRF, stateParam)
|
||||
}
|
||||
if !tt.shouldSucceed && csrfValid {
|
||||
t.Errorf("CSRF validation should have failed but succeeded. Stored: '%s', State: '%s'",
|
||||
storedCSRF, stateParam)
|
||||
}
|
||||
|
||||
t.Logf("CSRF validation test '%s': stored='%s', state='%s', valid=%v",
|
||||
tt.name, storedCSRF, stateParam, csrfValid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testEdgeCases tests various edge cases that could cause redirect loops
|
||||
func testEdgeCases(t *testing.T, sm *SessionManager) {
|
||||
t.Run("MissingHeaders", func(t *testing.T) {
|
||||
// Test with minimal headers
|
||||
req := httptest.NewRequest("GET", "http://localhost/callback", nil)
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session with minimal headers: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
session.SetCSRF("test-csrf")
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session with minimal headers: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies still have consistent configuration
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Cookie path inconsistent with minimal headers: got %s", cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Cookie SameSite inconsistent with minimal headers: got %v", cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DifferentDomains", func(t *testing.T) {
|
||||
domains := []string{"example.com", "auth.example.com", "sub.auth.example.com"}
|
||||
|
||||
for _, domain := range domains {
|
||||
req := httptest.NewRequest("GET", "http://"+domain+"/callback", nil)
|
||||
req.Header.Set("Host", domain)
|
||||
req.Header.Set("X-Forwarded-Host", domain)
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session for domain %s: %v", domain, err)
|
||||
}
|
||||
|
||||
session.SetCSRF("test-csrf-" + domain)
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session for domain %s: %v", domain, err)
|
||||
}
|
||||
|
||||
// Verify consistent cookie configuration across domains
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Domain %s: Cookie path inconsistent: got %s", domain, cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Domain %s: Cookie SameSite inconsistent: got %v", domain, cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.Clear(req, nil)
|
||||
t.Logf("Domain %s: Cookie configuration consistent", domain)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentSessions", func(t *testing.T) {
|
||||
// Test that multiple concurrent sessions don't interfere
|
||||
const numSessions = 5
|
||||
sessions := make([]*SessionData, numSessions)
|
||||
|
||||
for i := 0; i < numSessions; i++ {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session %d: %v", i, err)
|
||||
}
|
||||
sessions[i] = session
|
||||
|
||||
// Set unique data for each session
|
||||
session.SetCSRF("csrf-" + string(rune('A'+i)))
|
||||
session.SetNonce("nonce-" + string(rune('A'+i)))
|
||||
}
|
||||
|
||||
// Verify each session has its own data
|
||||
for i, session := range sessions {
|
||||
expectedCSRF := "csrf-" + string(rune('A'+i))
|
||||
expectedNonce := "nonce-" + string(rune('A'+i))
|
||||
|
||||
if session.GetCSRF() != expectedCSRF {
|
||||
t.Errorf("Session %d CSRF mismatch: expected %s, got %s",
|
||||
i, expectedCSRF, session.GetCSRF())
|
||||
}
|
||||
if session.GetNonce() != expectedNonce {
|
||||
t.Errorf("Session %d nonce mismatch: expected %s, got %s",
|
||||
i, expectedNonce, session.GetNonce())
|
||||
}
|
||||
|
||||
session.Clear(nil, nil)
|
||||
}
|
||||
|
||||
t.Log("Concurrent sessions test passed")
|
||||
})
|
||||
|
||||
t.Run("LargeCookieHandling", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req.Header.Set("Host", "example.com")
|
||||
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.Clear(req, nil)
|
||||
|
||||
// Test with large realistic JWT token that might require chunking
|
||||
largeToken := generateLargeRealisticToken()
|
||||
session.SetAccessToken(largeToken)
|
||||
session.SetCSRF("test-csrf")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save session with large token: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are still consistent even with chunking
|
||||
cookies := rw.Result().Cookies()
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_raczylo") {
|
||||
if cookie.Path != "/" {
|
||||
t.Errorf("Large cookie path inconsistent: got %s", cookie.Path)
|
||||
}
|
||||
if cookie.SameSite != http.SameSiteLaxMode {
|
||||
t.Errorf("Large cookie SameSite inconsistent: got %v", cookie.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify token can be retrieved correctly
|
||||
if session.GetAccessToken() != largeToken {
|
||||
t.Error("Large access token not retrieved correctly")
|
||||
}
|
||||
|
||||
t.Log("Large cookie handling test passed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSessionManagerEnhanceSessionSecurity tests the enhanced session security
|
||||
// to ensure SameSite is consistently Lax and not dynamically switched
|
||||
func TestSessionManagerEnhanceSessionSecurity(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expectSame http.SameSite
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "StandardRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "Standard request should use SameSite=Lax",
|
||||
},
|
||||
{
|
||||
name: "XMLHttpRequestHeader",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "XMLHttpRequest should still use SameSite=Lax (no dynamic switching)",
|
||||
},
|
||||
{
|
||||
name: "AjaxWithForwardedProto",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
expectSame: http.SameSiteLaxMode,
|
||||
description: "AJAX HTTPS request should use SameSite=Lax (no dynamic switching)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Test the EnhanceSessionSecurity method directly
|
||||
options := &sessions.Options{}
|
||||
enhanced := sm.EnhanceSessionSecurity(options, req)
|
||||
|
||||
if enhanced.SameSite != tt.expectSame {
|
||||
t.Errorf("Expected SameSite=%v, got SameSite=%v for %s",
|
||||
tt.expectSame, enhanced.SameSite, tt.description)
|
||||
}
|
||||
|
||||
// Verify Path is always "/"
|
||||
if enhanced.Path != "/" {
|
||||
t.Errorf("Expected Path='/', got Path='%s' for %s",
|
||||
enhanced.Path, tt.description)
|
||||
}
|
||||
|
||||
// Verify HttpOnly is always true
|
||||
if !enhanced.HttpOnly {
|
||||
t.Errorf("Expected HttpOnly=true, got HttpOnly=false for %s", tt.description)
|
||||
}
|
||||
|
||||
t.Logf("%s: SameSite=%v, Path=%s, HttpOnly=%v, Secure=%v",
|
||||
tt.name, enhanced.SameSite, enhanced.Path, enhanced.HttpOnly, enhanced.Secure)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCallbackHandlerIntegration tests the full callback handler integration
|
||||
// to ensure CSRF tokens work correctly with the fixed cookie configuration
|
||||
func TestCallbackHandlerIntegration(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
encryptionKey := "0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
|
||||
sm, err := NewSessionManager(encryptionKey, false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
// Simulate a complete OAuth flow with various request types
|
||||
scenarios := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "StandardBrowser",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "AjaxRequest",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
"X-Requested-With": "XMLHttpRequest",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HTTPSProxy",
|
||||
headers: map[string]string{
|
||||
"Host": "example.com",
|
||||
"User-Agent": "Mozilla/5.0 (Browser)",
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "example.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
t.Run(scenario.name, func(t *testing.T) {
|
||||
// Phase 1: Auth initiation - store CSRF token
|
||||
initReq := httptest.NewRequest("GET", "http://example.com/protected", nil)
|
||||
for key, value := range scenario.headers {
|
||||
initReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
initRw := httptest.NewRecorder()
|
||||
|
||||
session, err := sm.GetSession(initReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get init session: %v", err)
|
||||
}
|
||||
|
||||
csrfToken := uuid.New().String()
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetIncomingPath("/protected")
|
||||
|
||||
err = session.Save(initReq, initRw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save init session: %v", err)
|
||||
}
|
||||
|
||||
initCookies := initRw.Result().Cookies()
|
||||
|
||||
// Phase 2: OAuth callback - validate CSRF token access
|
||||
callbackReq := httptest.NewRequest("GET",
|
||||
"http://example.com/callback?state="+csrfToken+"&code=test_code", nil)
|
||||
|
||||
for key, value := range scenario.headers {
|
||||
callbackReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Add cookies from init phase
|
||||
for _, cookie := range initCookies {
|
||||
callbackReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
callbackSession, err := sm.GetSession(callbackReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get callback session: %v", err)
|
||||
}
|
||||
defer callbackSession.Clear(callbackReq, nil)
|
||||
|
||||
// This is the critical test - CSRF token must be accessible
|
||||
retrievedCSRF := callbackSession.GetCSRF()
|
||||
if retrievedCSRF == "" {
|
||||
t.Errorf("Scenario %s: CSRF token not accessible in callback", scenario.name)
|
||||
}
|
||||
if retrievedCSRF != csrfToken {
|
||||
t.Errorf("Scenario %s: CSRF token mismatch - expected %s, got %s",
|
||||
scenario.name, csrfToken, retrievedCSRF)
|
||||
}
|
||||
|
||||
// Validate state parameter matches CSRF token
|
||||
stateParam := callbackReq.URL.Query().Get("state")
|
||||
if stateParam != csrfToken {
|
||||
t.Errorf("Scenario %s: State parameter mismatch - expected %s, got %s",
|
||||
scenario.name, csrfToken, stateParam)
|
||||
}
|
||||
|
||||
// Simulate successful CSRF validation
|
||||
if retrievedCSRF != "" && retrievedCSRF == stateParam {
|
||||
t.Logf("Scenario %s: CSRF validation successful", scenario.name)
|
||||
} else {
|
||||
t.Errorf("Scenario %s: CSRF validation failed", scenario.name)
|
||||
}
|
||||
|
||||
// Verify other session data persisted
|
||||
if callbackSession.GetNonce() != "test-nonce" {
|
||||
t.Errorf("Scenario %s: Nonce not persisted", scenario.name)
|
||||
}
|
||||
if callbackSession.GetIncomingPath() != "/protected" {
|
||||
t.Errorf("Scenario %s: Incoming path not persisted", scenario.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -648,6 +648,13 @@ func (s *MockSession) SetEmail(email string) {
|
||||
s.values["email"] = email
|
||||
}
|
||||
|
||||
func (s *MockSession) GetEmail() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
email, _ := s.values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCSRF(csrf string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -39,6 +39,10 @@ type SessionData interface {
|
||||
GetCodeVerifier() string
|
||||
GetIncomingPath() string
|
||||
GetAuthenticated() bool
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
GetIDToken() string
|
||||
GetEmail() string
|
||||
SetAuthenticated(bool) error
|
||||
SetEmail(string)
|
||||
SetIDToken(string)
|
||||
@@ -100,6 +104,20 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Debug logging for cookie configuration
|
||||
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
|
||||
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
|
||||
|
||||
// Log all cookies in the request for debugging
|
||||
cookies := req.Cookies()
|
||||
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_") {
|
||||
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
|
||||
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
if errorDescription == "" {
|
||||
@@ -117,22 +135,36 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log the state parameter received
|
||||
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
||||
session.GetAuthenticated(), req.URL.String())
|
||||
|
||||
// Enhanced debugging for missing CSRF token
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
|
||||
}
|
||||
|
||||
// Log session state for debugging
|
||||
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
|
||||
session.GetAuthenticated(), session.GetAccessToken() != "")
|
||||
|
||||
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log successful CSRF token retrieval
|
||||
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
|
||||
|
||||
if state != csrfToken {
|
||||
h.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
|
||||
@@ -0,0 +1,541 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIssue67_InfiniteRefreshLoop reproduces and verifies the fix for issue #67
|
||||
// where concurrent requests with expired tokens caused an infinite refresh loop
|
||||
// leading to OOM conditions
|
||||
func TestIssue67_InfiniteRefreshLoop(t *testing.T) {
|
||||
// Track memory at start
|
||||
runtime.GC()
|
||||
var startMem runtime.MemStats
|
||||
runtime.ReadMemStats(&startMem)
|
||||
|
||||
// Create a mock authorization server
|
||||
var refreshAttempts int32
|
||||
var concurrentRefreshes int32
|
||||
var maxConcurrent int32
|
||||
|
||||
// Create a handler with server URL to be set after creation
|
||||
var serverURL string
|
||||
|
||||
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/token":
|
||||
// Track concurrent refresh attempts
|
||||
current := atomic.AddInt32(&concurrentRefreshes, 1)
|
||||
defer atomic.AddInt32(&concurrentRefreshes, -1)
|
||||
|
||||
// Update max concurrent
|
||||
for {
|
||||
max := atomic.LoadInt32(&maxConcurrent)
|
||||
if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
attempts := atomic.AddInt32(&refreshAttempts, 1)
|
||||
|
||||
// Simulate slow/failing token endpoint (like in the issue)
|
||||
if attempts < 5 {
|
||||
// First few attempts fail to trigger retries
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte(`{"error": "temporarily_unavailable"}`))
|
||||
} else {
|
||||
// Eventually succeed
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"access_token": "new_access_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"id_token": "new_id_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer"
|
||||
}`))
|
||||
}
|
||||
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(fmt.Sprintf(`{
|
||||
"issuer": "%s",
|
||||
"authorization_endpoint": "%s/authorize",
|
||||
"token_endpoint": "%s/token",
|
||||
"jwks_uri": "%s/keys",
|
||||
"response_types_supported": ["code"],
|
||||
"subject_types_supported": ["public"],
|
||||
"id_token_signing_alg_values_supported": ["RS256"],
|
||||
"scopes_supported": ["openid", "profile", "email"],
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
|
||||
"claims_supported": ["sub", "name", "email"]
|
||||
}`, serverURL, serverURL, serverURL, serverURL)))
|
||||
|
||||
case "/keys":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"keys": [{
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"kid": "test-key",
|
||||
"n": "test",
|
||||
"e": "AQAB"
|
||||
}]
|
||||
}`))
|
||||
}
|
||||
}))
|
||||
defer authServer.Close()
|
||||
|
||||
// Set the server URL after creation
|
||||
serverURL = authServer.URL
|
||||
|
||||
// Setup TraefikOIDC with refresh coordinator
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxRefreshAttempts = 3
|
||||
config.RefreshAttemptWindow = 1 * time.Second
|
||||
config.MaxConcurrentRefreshes = 2
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Simulate expired session
|
||||
expiredSession := &MockExpiredSession{
|
||||
refreshToken: "test_refresh_token",
|
||||
sessionID: "test_session",
|
||||
isExpired: true,
|
||||
}
|
||||
|
||||
// Simulate multiple concurrent requests (as reported in issue)
|
||||
numConcurrentRequests := 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConcurrentRequests)
|
||||
|
||||
// Track results
|
||||
var successCount int32
|
||||
var errorCount int32
|
||||
errors := make([]error, 0, numConcurrentRequests)
|
||||
var errorMutex sync.Mutex
|
||||
|
||||
// Launch concurrent requests with expired tokens
|
||||
startTime := time.Now()
|
||||
timeout := 5 * time.Second
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < numConcurrentRequests; i++ {
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each request tries to refresh the expired token
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
// Simulate calling the token endpoint
|
||||
resp, err := http.Post(
|
||||
serverURL+"/token",
|
||||
"application/x-www-form-urlencoded",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return &TokenResponse{
|
||||
AccessToken: fmt.Sprintf("new_access_%d", reqID),
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Use coordinator to prevent infinite loop
|
||||
result, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
expiredSession.sessionID,
|
||||
expiredSession.refreshToken,
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
atomic.AddInt32(&errorCount, 1)
|
||||
errorMutex.Lock()
|
||||
errors = append(errors, err)
|
||||
errorMutex.Unlock()
|
||||
} else if result != nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for completion or timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Completed normally
|
||||
case <-ctx.Done():
|
||||
t.Fatal("Test timed out - possible infinite loop detected!")
|
||||
}
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
|
||||
// Verify no infinite loop occurred
|
||||
if elapsed > timeout {
|
||||
t.Fatalf("Requests took too long: %v (possible infinite loop)", elapsed)
|
||||
}
|
||||
|
||||
// Check memory usage
|
||||
runtime.GC()
|
||||
var endMem runtime.MemStats
|
||||
runtime.ReadMemStats(&endMem)
|
||||
|
||||
// Calculate memory growth safely to prevent underflow
|
||||
var memGrowthMB float64
|
||||
if endMem.HeapAlloc >= startMem.HeapAlloc {
|
||||
memGrowthMB = float64(endMem.HeapAlloc-startMem.HeapAlloc) / (1024 * 1024)
|
||||
} else {
|
||||
// Memory decreased (GC occurred), treat as 0 growth
|
||||
memGrowthMB = 0
|
||||
}
|
||||
t.Logf("Memory stats: start=%d bytes, end=%d bytes, growth=%.2f MB",
|
||||
startMem.HeapAlloc, endMem.HeapAlloc, memGrowthMB)
|
||||
|
||||
// Memory should not grow excessively (issue reported OOM at 2GB)
|
||||
if memGrowthMB > 100 {
|
||||
t.Errorf("Excessive memory growth: %.2f MB (possible memory leak)", memGrowthMB)
|
||||
}
|
||||
|
||||
// Verify refresh deduplication worked
|
||||
actualRefreshAttempts := atomic.LoadInt32(&refreshAttempts)
|
||||
t.Logf("Total refresh attempts to server: %d", actualRefreshAttempts)
|
||||
t.Logf("Max concurrent refreshes: %d", maxConcurrent)
|
||||
t.Logf("Successful refreshes: %d", successCount)
|
||||
t.Logf("Failed refreshes: %d", errorCount)
|
||||
|
||||
// With deduplication, refresh attempts should be much less than concurrent requests
|
||||
if actualRefreshAttempts > int32(numConcurrentRequests/2) {
|
||||
t.Errorf("Too many refresh attempts (%d), deduplication not working properly",
|
||||
actualRefreshAttempts)
|
||||
}
|
||||
|
||||
// Max concurrent should respect our limit
|
||||
if maxConcurrent > int32(config.MaxConcurrentRefreshes) {
|
||||
t.Errorf("Max concurrent refreshes (%d) exceeded configured limit (%d)",
|
||||
maxConcurrent, config.MaxConcurrentRefreshes)
|
||||
}
|
||||
|
||||
// Check coordinator metrics
|
||||
metrics := coordinator.GetMetrics()
|
||||
t.Logf("Coordinator metrics: %+v", metrics)
|
||||
|
||||
if deduped, ok := metrics["deduplicated_requests"].(int64); ok {
|
||||
if deduped == 0 {
|
||||
t.Error("No requests were deduplicated - deduplication not working")
|
||||
}
|
||||
t.Logf("Deduplicated requests: %d", deduped)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssue67_WithoutCoordinator demonstrates the issue without the fix.
// WARNING: This test may consume significant memory - skip in CI.
//
// It is skipped in short mode and unless the test binary runs verbosely (-v).
// It makes no assertions: 100 goroutines each perform three simulated refresh
// attempts with no deduplication, and the resulting attempt count, peak
// concurrency and heap growth are only logged.
func TestIssue67_WithoutCoordinator(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping memory-intensive test in short mode")
	}

	// Only run this test with explicit flag to demonstrate the issue.
	if !testing.Verbose() {
		t.Skip("Skipping demonstration of issue without fix (run with -v to see)")
	}

	// Track memory at start.
	runtime.GC()
	var startMem runtime.MemStats
	runtime.ReadMemStats(&startMem)

	// Shared counters; accessed atomically because the workers may still be
	// running when the timeout below fires.
	var refreshAttempts int32
	var maxConcurrent int32
	var currentConcurrent int32

	// Simulate the issue: multiple goroutines attempting refresh without coordination.
	numRequests := 100
	var wg sync.WaitGroup
	wg.Add(numRequests)

	// Use a context with short timeout to prevent actual OOM.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	for i := 0; i < numRequests; i++ {
		go func(id int) {
			defer wg.Done()

			// Simulate retry logic without deduplication (the bug): every
			// attempt is treated as a failure, so all three attempts run.
			for attempt := 0; attempt < 3; attempt++ {
				select {
				case <-ctx.Done():
					return
				default:
				}

				current := atomic.AddInt32(&currentConcurrent, 1)

				// Raise the recorded maximum to current using a CAS loop.
				for {
					observed := atomic.LoadInt32(&maxConcurrent)
					if current <= observed || atomic.CompareAndSwapInt32(&maxConcurrent, observed, current) {
						break
					}
				}

				atomic.AddInt32(&refreshAttempts, 1)

				// Simulate token refresh with a linearly growing delay
				// (0ms, 100ms, 200ms).
				time.Sleep(time.Duration(attempt*100) * time.Millisecond)

				// Allocate memory to simulate token processing.
				_ = make([]byte, 1024*10) // 10KB per attempt

				atomic.AddInt32(&currentConcurrent, -1)
			}
		}(i)
	}

	// Wait with timeout.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// Completed.
	case <-ctx.Done():
		// Timed out (expected in problematic scenario).
	}

	// Check memory usage.
	runtime.GC()
	var endMem runtime.MemStats
	runtime.ReadMemStats(&endMem)

	// HeapAlloc is unsigned: only subtract when the heap actually grew,
	// otherwise the difference would wrap around to a huge value.
	var memGrowthMB float64
	if endMem.HeapAlloc >= startMem.HeapAlloc {
		memGrowthMB = float64(endMem.HeapAlloc-startMem.HeapAlloc) / (1024 * 1024)
	}

	attempts := atomic.LoadInt32(&refreshAttempts)
	peak := atomic.LoadInt32(&maxConcurrent)

	t.Logf("WITHOUT COORDINATOR:")
	t.Logf("  Refresh attempts: %d", attempts)
	t.Logf("  Max concurrent: %d", peak)
	t.Logf("  Memory growth: %.2f MB", memGrowthMB)

	// This demonstrates the issue - high concurrency and many attempts.
	if attempts < int32(numRequests*2) {
		t.Logf("Note: Without coordinator, saw %d refresh attempts for %d requests",
			attempts, numRequests)
	}
}
|
||||
|
||||
// MockExpiredSession simulates an expired session for testing. It is a plain
// value holder: the getters below return the fields unchanged.
type MockExpiredSession struct {
	refreshToken string
	sessionID    string
	isExpired    bool
}

// GetRefreshToken returns the refresh token stored in the mock.
func (m *MockExpiredSession) GetRefreshToken() string {
	return m.refreshToken
}

// GetSessionID returns the session identifier stored in the mock.
func (m *MockExpiredSession) GetSessionID() string {
	return m.sessionID
}

// IsExpired reports the expiry flag stored in the mock.
func (m *MockExpiredSession) IsExpired() bool {
	return m.isExpired
}
|
||||
|
||||
// BenchmarkRefreshWithCoordinator measures performance with the fix
|
||||
func BenchmarkRefreshWithCoordinator(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
// Simulate token refresh
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &TokenResponse{
|
||||
AccessToken: "new_token",
|
||||
RefreshToken: "new_refresh",
|
||||
}, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
ctx := context.Background()
|
||||
sessionID := fmt.Sprintf("session_%d", i%10)
|
||||
refreshToken := "refresh_token"
|
||||
|
||||
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
i++
|
||||
}
|
||||
})
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
metrics := coordinator.GetMetrics()
|
||||
b.Logf("Total requests: %v", metrics["total_requests"])
|
||||
b.Logf("Deduplicated: %v", metrics["deduplicated_requests"])
|
||||
b.Logf("Success rate: %.2f%%",
|
||||
float64(metrics["successful_refreshes"].(int64))/
|
||||
float64(metrics["total_requests"].(int64))*100)
|
||||
}
|
||||
|
||||
// TestRefreshCoordinatorIntegration tests the full integration
|
||||
func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
// This test verifies the coordinator integrates properly with:
|
||||
// 1. Circuit breaker
|
||||
// 2. Rate limiting
|
||||
// 3. Deduplication
|
||||
// 4. Memory management
|
||||
// 5. Cleanup routines
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxRefreshAttempts = 5
|
||||
config.RefreshAttemptWindow = 1 * time.Second
|
||||
config.RefreshCooldownPeriod = 2 * time.Second
|
||||
config.MaxConcurrentRefreshes = 3
|
||||
config.CleanupInterval = 500 * time.Millisecond
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Test 1: Normal operation
|
||||
t.Run("NormalOperation", func(t *testing.T) {
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
return &TokenResponse{AccessToken: "token1"}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
result, err := coordinator.CoordinateRefresh(ctx, "session1", "refresh1", refreshFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Normal refresh failed: %v", err)
|
||||
}
|
||||
if result == nil || result.AccessToken != "token1" {
|
||||
t.Error("Invalid result from normal refresh")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 2: Circuit breaker activation
|
||||
t.Run("CircuitBreaker", func(t *testing.T) {
|
||||
failingRefresh := func() (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("service unavailable")
|
||||
}
|
||||
|
||||
// Trigger circuit breaker
|
||||
for i := 0; i < 4; i++ {
|
||||
ctx := context.Background()
|
||||
_, _ = coordinator.CoordinateRefresh(ctx,
|
||||
fmt.Sprintf("cb_session_%d", i), "refresh_cb", failingRefresh)
|
||||
}
|
||||
|
||||
// Next request should be blocked by circuit breaker
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, "cb_session_blocked", "refresh_cb", failingRefresh)
|
||||
|
||||
if err == nil || !strings.Contains(err.Error(), "circuit breaker") {
|
||||
t.Errorf("Circuit breaker should have blocked request: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test 3: Rate limiting
|
||||
t.Run("RateLimiting", func(t *testing.T) {
|
||||
// Reset circuit breaker to closed state for this test
|
||||
coordinator.circuitBreaker.mutex.Lock()
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
|
||||
coordinator.circuitBreaker.mutex.Unlock()
|
||||
|
||||
// Temporarily increase circuit breaker threshold to not interfere
|
||||
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
|
||||
coordinator.circuitBreaker.config.MaxFailures = 20
|
||||
defer func() {
|
||||
coordinator.circuitBreaker.config.MaxFailures = oldMaxFailures
|
||||
}()
|
||||
|
||||
failingRefresh := func() (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("failed")
|
||||
}
|
||||
|
||||
sessionID := "rate_limit_session"
|
||||
|
||||
// Exhaust attempts
|
||||
for i := 0; i < config.MaxRefreshAttempts+1; i++ {
|
||||
ctx := context.Background()
|
||||
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh)
|
||||
// Add delay to ensure operations complete and aren't deduplicated
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Should be in cooldown
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, sessionID, "refresh_rl", failingRefresh)
|
||||
|
||||
if err == nil || !strings.Contains(err.Error(), "cooldown") {
|
||||
t.Errorf("Rate limiting should have triggered cooldown: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test 4: Cleanup
|
||||
t.Run("Cleanup", func(t *testing.T) {
|
||||
// Add some sessions
|
||||
for i := 0; i < 5; i++ {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("cleanup_session_%d", i))
|
||||
}
|
||||
|
||||
// Wait for cleanup
|
||||
time.Sleep(config.CleanupInterval * 3)
|
||||
|
||||
// Old sessions should be cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
count := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
// Should have fewer sessions after cleanup
|
||||
if count > 10 {
|
||||
t.Errorf("Cleanup not working, %d sessions remain", count)
|
||||
}
|
||||
})
|
||||
|
||||
// Verify final metrics
|
||||
metrics := coordinator.GetMetrics()
|
||||
t.Logf("Final metrics: %+v", metrics)
|
||||
}
|
||||
@@ -32,6 +32,12 @@ func GetMemoryOptimizations() *MemoryOptimizations {
|
||||
return globalMemoryOpts
|
||||
}
|
||||
|
||||
// ResetGlobalMemoryOptimizations resets the global memory optimizations for testing
|
||||
func ResetGlobalMemoryOptimizations() {
|
||||
globalMemoryOptsOnce = sync.Once{}
|
||||
globalMemoryOpts = nil
|
||||
}
|
||||
|
||||
// BufferPool manages a pool of byte buffers
|
||||
type BufferPool struct {
|
||||
pool sync.Pool
|
||||
|
||||
+3
-1
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -95,7 +96,8 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
|
||||
}
|
||||
|
||||
// Fetch from provider
|
||||
metadataURL := providerURL + "/.well-known/openid-configuration"
|
||||
// Ensure no double slashes by trimming trailing slash from provider URL
|
||||
metadataURL := strings.TrimRight(providerURL, "/") + "/.well-known/openid-configuration"
|
||||
mc.logger.Infof("Fetching provider metadata from: %s", metadataURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil)
|
||||
|
||||
+34
-3
@@ -7,10 +7,27 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// isRaceDetectorEnabled reports whether the Go race detector appears to be
// enabled for the running binary.
//
// It first looks for the "-race" build setting with value "true" in the
// binary's embedded build information. If that setting is absent, or build
// information is unavailable altogether, it falls back to a heuristic: a
// non-empty GORACE environment variable is taken as a sign that the race
// detector is in use.
func isRaceDetectorEnabled() bool {
	if info, ok := debug.ReadBuildInfo(); ok {
		for _, setting := range info.Settings {
			if setting.Key == "-race" && setting.Value == "true" {
				return true
			}
		}
	}

	// Alternative method: check if the GORACE environment variable is set.
	// This also covers binaries without readable build information.
	return os.Getenv("GORACE") != ""
}
|
||||
|
||||
func TestProfilingManager(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
pm := NewProfilingManager(logger)
|
||||
@@ -462,13 +479,27 @@ func TestProviderMetadataMemoryLeakDetection(t *testing.T) {
|
||||
}
|
||||
|
||||
// Phase 2: Continue with more fetches to test sustained operation
|
||||
t.Log("Phase 2: Testing sustained operation with 1000 iterations...")
|
||||
for i := 20; i < 1020; i++ {
|
||||
// Adjust iterations based on race detector presence to avoid timeouts
|
||||
var phase2Iterations int
|
||||
var sleepDuration time.Duration
|
||||
if isRaceDetectorEnabled() {
|
||||
// With race detector: reduce iterations significantly to stay well within timeout
|
||||
phase2Iterations = 100
|
||||
sleepDuration = 100 * time.Millisecond // Slightly longer sleep to reduce CPU contention
|
||||
t.Log("Phase 2: Testing sustained operation with 100 iterations (race detector enabled)...")
|
||||
} else {
|
||||
// Without race detector: use original values for thorough testing
|
||||
phase2Iterations = 1000
|
||||
sleepDuration = 50 * time.Millisecond
|
||||
t.Log("Phase 2: Testing sustained operation with 1000 iterations...")
|
||||
}
|
||||
|
||||
for i := 20; i < 20+phase2Iterations; i++ {
|
||||
_, err := metadataCache.GetMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
t.Logf("Metadata fetch %d failed: %v", i+1, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // Reduced sleep for faster execution
|
||||
time.Sleep(sleepDuration)
|
||||
}
|
||||
|
||||
// Take final snapshot
|
||||
|
||||
@@ -0,0 +1,596 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RefreshCoordinator prevents duplicate refresh token operations and manages
|
||||
// refresh attempt tracking to prevent infinite loops and OOM conditions.
|
||||
// It implements request coalescing, rate limiting, and circuit breaking
|
||||
// specifically for token refresh operations.
|
||||
type RefreshCoordinator struct {
|
||||
// inFlightRefreshes tracks active refresh operations by refresh token hash
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
// refreshMutex protects the inFlightRefreshes map
|
||||
refreshMutex sync.RWMutex
|
||||
|
||||
// sessionRefreshAttempts tracks refresh attempts per session
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
// attemptsMutex protects sessionRefreshAttempts map
|
||||
attemptsMutex sync.RWMutex
|
||||
|
||||
// Circuit breaker for refresh operations
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
|
||||
// Configuration
|
||||
config RefreshCoordinatorConfig
|
||||
|
||||
// Metrics
|
||||
metrics *RefreshMetrics
|
||||
|
||||
// Logger
|
||||
logger *Logger
|
||||
|
||||
// Cleanup goroutine control
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// RefreshCoordinatorConfig configures the refresh coordinator behavior.
type RefreshCoordinatorConfig struct {
	// MaxRefreshAttempts is the maximum number of refresh attempts per session
	// before giving up.
	MaxRefreshAttempts int

	// RefreshAttemptWindow is the time window for refresh attempt tracking.
	RefreshAttemptWindow time.Duration

	// RefreshCooldownPeriod is the cooldown period after max attempts reached.
	RefreshCooldownPeriod time.Duration

	// MaxConcurrentRefreshes is the maximum number of concurrent refresh operations.
	MaxConcurrentRefreshes int

	// RefreshTimeout is the timeout for individual refresh operations.
	RefreshTimeout time.Duration

	// EnableMemoryPressureDetection enables memory pressure detection.
	EnableMemoryPressureDetection bool

	// MemoryPressureThresholdMB is the memory pressure threshold (in MB).
	MemoryPressureThresholdMB uint64

	// CleanupInterval is the cleanup interval for stale entries.
	CleanupInterval time.Duration

	// DeduplicationCleanupDelay is the delay before cleaning up completed
	// refresh operations from the deduplication map.
	// Set to 0 for immediate cleanup (useful for tests).
	DeduplicationCleanupDelay time.Duration
}

// DefaultRefreshCoordinatorConfig returns production-ready configuration:
// 5 attempts per 5 minute window, a 10 minute cooldown, up to 10 concurrent
// refreshes, a 30 second refresh timeout, memory pressure detection enabled
// with a 500 MB threshold, a 1 minute cleanup interval and a 100 ms
// deduplication cleanup delay.
func DefaultRefreshCoordinatorConfig() RefreshCoordinatorConfig {
	return RefreshCoordinatorConfig{
		MaxRefreshAttempts:            5,
		RefreshAttemptWindow:          5 * time.Minute,
		RefreshCooldownPeriod:         10 * time.Minute,
		MaxConcurrentRefreshes:        10,
		RefreshTimeout:                30 * time.Second,
		EnableMemoryPressureDetection: true,
		MemoryPressureThresholdMB:     500, // 500MB threshold
		CleanupInterval:               1 * time.Minute,
		DeduplicationCleanupDelay:     100 * time.Millisecond, // Default 100ms for production
	}
}
|
||||
|
||||
// refreshOperation represents an in-flight refresh operation
|
||||
type refreshOperation struct {
|
||||
// refreshToken being refreshed (for validation)
|
||||
refreshToken string
|
||||
// result stores the final result
|
||||
result *refreshResult
|
||||
// done signals when the operation is complete
|
||||
done chan struct{}
|
||||
// startTime tracks when the operation started
|
||||
startTime time.Time
|
||||
// waiterCount tracks number of goroutines waiting
|
||||
waiterCount int32
|
||||
// mutex protects the result field
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// refreshResult contains the result of a refresh operation
|
||||
type refreshResult struct {
|
||||
tokenResponse *TokenResponse
|
||||
err error
|
||||
fromCache bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session.
type refreshAttemptTracker struct {
	// attempts counts refresh attempts in current window.
	attempts int32

	// lastAttemptTime is the timestamp of the last attempt.
	lastAttemptTime time.Time

	// windowStartTime is when the current tracking window started.
	windowStartTime time.Time

	// inCooldown indicates if this session is in cooldown.
	inCooldown bool

	// cooldownEndTime is when cooldown period ends.
	cooldownEndTime time.Time

	// consecutiveFailures tracks consecutive refresh failures.
	consecutiveFailures int32
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics.
//
// The coordinator updates these counters with sync/atomic operations. The
// int64 fields are declared first so that they stay 64-bit aligned, which
// atomic 64-bit access requires on 32-bit platforms; keep new int64 counters
// above currentInFlightRefreshes.
type RefreshMetrics struct {
	totalRefreshRequests     int64
	deduplicatedRequests     int64
	successfulRefreshes      int64
	failedRefreshes          int64
	circuitBreakerTrips      int64
	memoryPressureEvents     int64
	cooldownsTriggered       int64
	currentInFlightRefreshes int32
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
type RefreshCircuitBreaker struct {
|
||||
state int32 // 0=closed, 1=open, 2=half-open
|
||||
failures int32
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker.
// NOTE(review): the field meanings below are inferred from their names and
// from the breaker's closed/open/half-open states - confirm against the
// breaker implementation.
type RefreshCircuitBreakerConfig struct {
	// MaxFailures is presumably the failure count at which the breaker opens.
	MaxFailures int

	// OpenDuration is presumably how long the breaker stays open.
	OpenDuration time.Duration

	// HalfOpenRequests is presumably the number of trial requests allowed
	// while half-open.
	HalfOpenRequests int
}
|
||||
|
||||
// NewRefreshCoordinator creates a new refresh coordinator
|
||||
func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *RefreshCoordinator {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
rc := &RefreshCoordinator{
|
||||
inFlightRefreshes: make(map[string]*refreshOperation),
|
||||
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
OpenDuration: 30 * time.Second,
|
||||
HalfOpenRequests: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
rc.wg.Add(1)
|
||||
go rc.cleanupRoutine()
|
||||
|
||||
return rc
|
||||
}
|
||||
|
||||
// CoordinateRefresh ensures only one refresh operation happens per refresh token
|
||||
// and implements request coalescing for concurrent refresh attempts
|
||||
func (rc *RefreshCoordinator) CoordinateRefresh(
|
||||
ctx context.Context,
|
||||
sessionID string,
|
||||
refreshToken string,
|
||||
refreshFunc func() (*TokenResponse, error),
|
||||
) (*TokenResponse, error) {
|
||||
// Increment total request count
|
||||
atomic.AddInt64(&rc.metrics.totalRefreshRequests, 1)
|
||||
|
||||
// Check circuit breaker first
|
||||
if !rc.circuitBreaker.AllowRequest() {
|
||||
atomic.AddInt64(&rc.metrics.circuitBreakerTrips, 1)
|
||||
return nil, fmt.Errorf("refresh circuit breaker is open due to repeated failures")
|
||||
}
|
||||
|
||||
// Create hash of refresh token for deduplication
|
||||
tokenHash := rc.hashRefreshToken(refreshToken)
|
||||
|
||||
// CRITICAL FIX: Atomically check for existing operation OR create new one
|
||||
// This prevents the race where multiple goroutines check, find nothing, then all create
|
||||
operation, isNew, err := rc.getOrCreateOperation(ctx, sessionID, tokenHash, refreshToken)
|
||||
|
||||
if err != nil {
|
||||
// Operation creation was rejected (rate limit, memory pressure, concurrent limit)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNew {
|
||||
// We created a new operation, so we need to execute it
|
||||
go rc.executeRefreshAsync(operation, sessionID, tokenHash, refreshFunc)
|
||||
} else {
|
||||
// Joined existing operation - this is a deduplicated request
|
||||
atomic.AddInt64(&rc.metrics.deduplicatedRequests, 1)
|
||||
}
|
||||
|
||||
// Wait for the operation to complete
|
||||
select {
|
||||
case <-operation.done:
|
||||
// Get the result
|
||||
operation.mutex.RLock()
|
||||
result := operation.result
|
||||
operation.mutex.RUnlock()
|
||||
|
||||
if result != nil {
|
||||
// Record metrics based on result
|
||||
if result.err != nil {
|
||||
rc.circuitBreaker.RecordFailure()
|
||||
rc.recordRefreshFailure(sessionID)
|
||||
atomic.AddInt64(&rc.metrics.failedRefreshes, 1)
|
||||
} else {
|
||||
rc.circuitBreaker.RecordSuccess()
|
||||
rc.recordRefreshSuccess(sessionID)
|
||||
atomic.AddInt64(&rc.metrics.successfulRefreshes, 1)
|
||||
}
|
||||
return result.tokenResponse, result.err
|
||||
}
|
||||
return nil, fmt.Errorf("refresh operation completed without result")
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateOperation atomically checks for an existing operation or creates a new one
|
||||
// Returns (operation, true, nil) if a new operation was created
|
||||
// Returns (operation, false, nil) if joined an existing operation
|
||||
// Returns (nil, false, error) if the operation was rejected
|
||||
func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
ctx context.Context,
|
||||
sessionID string,
|
||||
tokenHash string,
|
||||
refreshToken string,
|
||||
) (*refreshOperation, bool, error) {
|
||||
rc.refreshMutex.Lock()
|
||||
defer rc.refreshMutex.Unlock()
|
||||
|
||||
// Check for existing operation while holding the lock
|
||||
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
|
||||
if existingOp.refreshToken == refreshToken {
|
||||
// Join existing operation
|
||||
atomic.AddInt32(&existingOp.waiterCount, 1)
|
||||
return existingOp, false, nil
|
||||
}
|
||||
// Different refresh token for same hash - should not happen
|
||||
return nil, false, fmt.Errorf("refresh token mismatch")
|
||||
}
|
||||
|
||||
// No existing operation - check if we can create a new one
|
||||
// All checks happen while holding the lock to prevent races
|
||||
|
||||
// Check and record refresh attempt for rate limiting
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
if rc.isInCooldown(sessionID) {
|
||||
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
|
||||
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
|
||||
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
|
||||
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
|
||||
}
|
||||
|
||||
// Check and reserve concurrent refresh slot atomically
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
}
|
||||
|
||||
// Reserve the slot - we're still holding the lock so this is safe
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
|
||||
// Create and register new operation
|
||||
operation := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
rc.inFlightRefreshes[tokenHash] = operation
|
||||
|
||||
return operation, true, nil
|
||||
}
|
||||
|
||||
// executeRefreshAsync performs the actual refresh operation asynchronously
|
||||
func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
operation *refreshOperation,
|
||||
sessionID string,
|
||||
tokenHash string,
|
||||
refreshFunc func() (*TokenResponse, error),
|
||||
) {
|
||||
defer func() {
|
||||
// Signal completion to all waiters
|
||||
close(operation.done)
|
||||
|
||||
// Clean up operation after a configurable delay to allow waiters to read result
|
||||
go func() {
|
||||
if rc.config.DeduplicationCleanupDelay > 0 {
|
||||
time.Sleep(rc.config.DeduplicationCleanupDelay)
|
||||
}
|
||||
rc.refreshMutex.Lock()
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
rc.refreshMutex.Unlock()
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}()
|
||||
}()
|
||||
|
||||
// Create timeout context
|
||||
refreshCtx, cancel := context.WithTimeout(context.Background(), rc.config.RefreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute refresh in goroutine to respect timeout
|
||||
resultChan := make(chan struct {
|
||||
resp *TokenResponse
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
go func() {
|
||||
resp, err := refreshFunc()
|
||||
select {
|
||||
case resultChan <- struct {
|
||||
resp *TokenResponse
|
||||
err error
|
||||
}{resp, err}:
|
||||
case <-refreshCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
// Store result for all waiters
|
||||
operation.mutex.Lock()
|
||||
operation.result = &refreshResult{
|
||||
tokenResponse: result.resp,
|
||||
err: result.err,
|
||||
fromCache: false,
|
||||
}
|
||||
operation.mutex.Unlock()
|
||||
case <-refreshCtx.Done():
|
||||
// Timeout occurred
|
||||
timeoutErr := fmt.Errorf("refresh operation timed out after %v", rc.config.RefreshTimeout)
|
||||
operation.mutex.Lock()
|
||||
operation.result = &refreshResult{
|
||||
tokenResponse: nil,
|
||||
err: timeoutErr,
|
||||
fromCache: false,
|
||||
}
|
||||
operation.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown after recording an attempt
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
return false // No tracker means first attempt, not in cooldown
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Check if already in cooldown
|
||||
if tracker.inCooldown {
|
||||
if now.After(tracker.cooldownEndTime) {
|
||||
// Cooldown expired, reset tracker
|
||||
tracker.inCooldown = false
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.consecutiveFailures = 0
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
}
|
||||
return true // Still in cooldown
|
||||
}
|
||||
|
||||
// Check if window expired
|
||||
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
|
||||
// Reset window
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if just exceeded attempt limit
|
||||
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
|
||||
// Enter cooldown now
|
||||
tracker.inCooldown = true
|
||||
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, tracker.attempts)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting
|
||||
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
tracker = &refreshAttemptTracker{
|
||||
windowStartTime: time.Now(),
|
||||
}
|
||||
rc.sessionRefreshAttempts[sessionID] = tracker
|
||||
}
|
||||
|
||||
atomic.AddInt32(&tracker.attempts, 1)
|
||||
tracker.lastAttemptTime = time.Now()
|
||||
}
|
||||
|
||||
// recordRefreshSuccess records a successful refresh
|
||||
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
tracker.consecutiveFailures = 0
|
||||
}
|
||||
}
|
||||
|
||||
// recordRefreshFailure records a failed refresh
|
||||
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
atomic.AddInt32(&tracker.consecutiveFailures, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure
|
||||
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
|
||||
// This is a simplified check - in production you'd want to use runtime.MemStats
|
||||
// or system-specific memory monitoring
|
||||
return false // Placeholder - implement actual memory check
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically cleans up stale tracking entries
|
||||
func (rc *RefreshCoordinator) cleanupRoutine() {
|
||||
defer rc.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(rc.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rc.cleanupStaleEntries()
|
||||
case <-rc.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleEntries removes outdated tracking entries
|
||||
func (rc *RefreshCoordinator) cleanupStaleEntries() {
|
||||
now := time.Now()
|
||||
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
// Clean up old session trackers
|
||||
for sessionID, tracker := range rc.sessionRefreshAttempts {
|
||||
// Remove trackers that haven't been used recently
|
||||
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
|
||||
delete(rc.sessionRefreshAttempts, sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns current coordinator metrics
|
||||
func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"total_requests": atomic.LoadInt64(&rc.metrics.totalRefreshRequests),
|
||||
"deduplicated_requests": atomic.LoadInt64(&rc.metrics.deduplicatedRequests),
|
||||
"successful_refreshes": atomic.LoadInt64(&rc.metrics.successfulRefreshes),
|
||||
"failed_refreshes": atomic.LoadInt64(&rc.metrics.failedRefreshes),
|
||||
"circuit_breaker_trips": atomic.LoadInt64(&rc.metrics.circuitBreakerTrips),
|
||||
"memory_pressure_events": atomic.LoadInt64(&rc.metrics.memoryPressureEvents),
|
||||
"cooldowns_triggered": atomic.LoadInt64(&rc.metrics.cooldownsTriggered),
|
||||
"current_inflight": atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes),
|
||||
"circuit_breaker_state": rc.circuitBreaker.GetState(),
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the coordinator
|
||||
func (rc *RefreshCoordinator) Shutdown() {
|
||||
close(rc.stopChan)
|
||||
rc.wg.Wait()
|
||||
}
|
||||
|
||||
// AllowRequest checks if the circuit breaker allows a request
|
||||
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
switch state {
|
||||
case 0: // Closed
|
||||
return true
|
||||
case 1: // Open
|
||||
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
|
||||
// Try to transition to half-open
|
||||
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case 2: // Half-open
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (cb *RefreshCircuitBreaker) RecordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == 2 { // Half-open
|
||||
// Close the circuit
|
||||
atomic.StoreInt32(&cb.state, 0)
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
} else if state == 0 { // Closed
|
||||
// Reset failure count on success
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
}
|
||||
cb.lastSuccessTime = time.Now()
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (cb *RefreshCircuitBreaker) RecordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
failures := atomic.AddInt32(&cb.failures, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
if state == 0 && int(failures) >= cb.config.MaxFailures {
|
||||
// Open the circuit
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
} else if state == 2 {
|
||||
// Half-open failed, return to open
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *RefreshCircuitBreaker) GetState() string {
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
switch state {
|
||||
case 0:
|
||||
return "closed"
|
||||
case 1:
|
||||
return "open"
|
||||
case 2:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,669 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestConcurrentRefreshDeduplication verifies that concurrent refresh attempts
|
||||
// for the same token are deduplicated and only one refresh operation occurs
|
||||
func TestConcurrentRefreshDeduplication(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
// Keep default delay for this test - it's testing deduplication behavior
|
||||
// Disable rate limiting for this test since we're testing deduplication
|
||||
config.MaxRefreshAttempts = 1000 // High enough to not interfere
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Counter to track actual refresh executions
|
||||
var refreshExecutions int32
|
||||
|
||||
// Mock refresh function
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
atomic.AddInt32(&refreshExecutions, 1)
|
||||
// Simulate some processing time
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return &TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
IDToken: "new_id_token",
|
||||
ExpiresIn: 3600,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Number of concurrent requests
|
||||
numRequests := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
// Channel to collect results
|
||||
results := make(chan *TokenResponse, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
// Launch concurrent refresh attempts with unique identifiers
|
||||
refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())
|
||||
sessionID := fmt.Sprintf("test_session_%d", time.Now().UnixNano())
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(reqID int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errors <- err
|
||||
} else {
|
||||
results <- resp
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
wg.Wait()
|
||||
close(results)
|
||||
close(errors)
|
||||
|
||||
// Verify results
|
||||
actualExecutions := atomic.LoadInt32(&refreshExecutions)
|
||||
// Allow for slight timing variations - up to 2 executions is acceptable
|
||||
// This can happen when a second goroutine starts just as the first completes
|
||||
if actualExecutions > 2 {
|
||||
t.Errorf("Expected 1-2 refresh executions, got %d", actualExecutions)
|
||||
}
|
||||
|
||||
// Verify all requests got the same result
|
||||
var firstResponse *TokenResponse
|
||||
responseCount := 0
|
||||
|
||||
for resp := range results {
|
||||
responseCount++
|
||||
if firstResponse == nil {
|
||||
firstResponse = resp
|
||||
} else {
|
||||
// All responses should be identical (same pointer)
|
||||
if resp.AccessToken != firstResponse.AccessToken {
|
||||
t.Error("Different responses returned for concurrent requests")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for range errors {
|
||||
errorCount++
|
||||
}
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Errorf("Unexpected errors in concurrent requests: %d", errorCount)
|
||||
}
|
||||
|
||||
if responseCount != numRequests {
|
||||
t.Errorf("Expected %d successful responses, got %d", numRequests, responseCount)
|
||||
}
|
||||
|
||||
// Verify metrics
|
||||
metrics := coordinator.GetMetrics()
|
||||
if deduped, ok := metrics["deduplicated_requests"].(int64); ok {
|
||||
// Allow for slight timing variations - at least 98 out of 100 should be deduplicated
|
||||
if deduped < int64(numRequests-2) {
|
||||
t.Errorf("Expected at least %d deduplicated requests, got %d", numRequests-2, deduped)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshRateLimiting verifies that refresh attempts are rate-limited per session
|
||||
func TestRefreshRateLimiting(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxRefreshAttempts = 3
|
||||
config.RefreshAttemptWindow = 1 * time.Second
|
||||
config.RefreshCooldownPeriod = 2 * time.Second
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Set circuit breaker to not interfere with rate limiting test
|
||||
// We want to test rate limiting, not circuit breaker
|
||||
coordinator.circuitBreaker.config.MaxFailures = 10
|
||||
|
||||
sessionID := "rate_limited_session"
|
||||
refreshToken := "test_refresh_token"
|
||||
|
||||
// Mock refresh function that always fails
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("refresh failed")
|
||||
}
|
||||
|
||||
// Attempt refreshes beyond the limit
|
||||
var attempts int
|
||||
var cooldownTriggered bool
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "refresh attempts exceeded for session, in cooldown period" {
|
||||
cooldownTriggered = true
|
||||
break
|
||||
}
|
||||
}
|
||||
attempts++
|
||||
// Add delay to ensure operations complete and aren't deduplicated
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Verify that cooldown was triggered after max attempts
|
||||
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
|
||||
if attempts != expectedSuccessfulAttempts {
|
||||
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
|
||||
}
|
||||
|
||||
if !cooldownTriggered {
|
||||
t.Error("Cooldown was not triggered after max attempts")
|
||||
}
|
||||
|
||||
// Verify that requests are blocked during cooldown
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Request should be blocked during cooldown period")
|
||||
}
|
||||
|
||||
// Wait for cooldown to expire
|
||||
time.Sleep(config.RefreshCooldownPeriod + 100*time.Millisecond)
|
||||
|
||||
// Verify that requests are allowed after cooldown
|
||||
_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err != nil && err.Error() == "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Request should be allowed after cooldown period")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreakerProtection verifies that the circuit breaker prevents
|
||||
// cascading failures during repeated refresh failures
|
||||
func TestCircuitBreakerProtection(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Set circuit breaker to trip after 3 failures
|
||||
coordinator.circuitBreaker.config.MaxFailures = 3
|
||||
coordinator.circuitBreaker.config.OpenDuration = 1 * time.Second
|
||||
|
||||
// Mock refresh function that always fails
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("service unavailable")
|
||||
}
|
||||
|
||||
// Cause circuit breaker to trip
|
||||
var tripCount int
|
||||
for i := 0; i < 5; i++ {
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
fmt.Sprintf("session_%d", i), // Different sessions
|
||||
"refresh_token",
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil && err.Error() == "refresh circuit breaker is open due to repeated failures" {
|
||||
tripCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Verify circuit breaker tripped
|
||||
if tripCount == 0 {
|
||||
t.Error("Circuit breaker did not trip after repeated failures")
|
||||
}
|
||||
|
||||
// Verify circuit breaker state
|
||||
if coordinator.circuitBreaker.GetState() != "open" {
|
||||
t.Errorf("Expected circuit breaker state 'open', got '%s'", coordinator.circuitBreaker.GetState())
|
||||
}
|
||||
|
||||
// Wait for circuit to transition to half-open
|
||||
time.Sleep(coordinator.circuitBreaker.config.OpenDuration + 100*time.Millisecond)
|
||||
|
||||
// Mock successful refresh
|
||||
successfulRefreshFunc := func() (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
AccessToken: "new_token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify circuit allows request in half-open state
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, "session_recovery", "refresh_token", successfulRefreshFunc)
|
||||
if err != nil {
|
||||
t.Errorf("Circuit breaker should allow request in half-open state: %v", err)
|
||||
}
|
||||
|
||||
// Verify circuit closed after success
|
||||
if coordinator.circuitBreaker.GetState() != "closed" {
|
||||
t.Errorf("Expected circuit breaker state 'closed' after successful request, got '%s'",
|
||||
coordinator.circuitBreaker.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMemoryLeakPrevention verifies that the coordinator doesn't leak memory
|
||||
// during sustained concurrent refresh operations
|
||||
func TestMemoryLeakPrevention(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory leak test in short mode")
|
||||
}
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.CleanupInterval = 100 * time.Millisecond
|
||||
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Force garbage collection and record initial memory
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
var initialMem runtime.MemStats
|
||||
runtime.ReadMemStats(&initialMem)
|
||||
|
||||
// Run sustained concurrent operations
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numWorkers := 10
|
||||
wg.Add(numWorkers)
|
||||
|
||||
// Each worker continuously attempts refreshes
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
refreshCount := 0
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
// Simulate varying response times
|
||||
time.Sleep(time.Duration(workerID*10) * time.Millisecond)
|
||||
return &TokenResponse{
|
||||
AccessToken: fmt.Sprintf("token_%d_%d", workerID, refreshCount),
|
||||
RefreshToken: fmt.Sprintf("refresh_%d_%d", workerID, refreshCount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
sessionID := fmt.Sprintf("session_%d", workerID)
|
||||
refreshToken := fmt.Sprintf("refresh_%d_%d", workerID, refreshCount)
|
||||
|
||||
_, _ = coordinator.CoordinateRefresh(
|
||||
context.Background(),
|
||||
sessionID,
|
||||
refreshToken,
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
refreshCount++
|
||||
// Small delay to prevent CPU saturation
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for workers to complete
|
||||
wg.Wait()
|
||||
|
||||
// Allow cleanup to run
|
||||
time.Sleep(2 * config.CleanupInterval)
|
||||
|
||||
// Force garbage collection and check memory
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
var finalMem runtime.MemStats
|
||||
runtime.ReadMemStats(&finalMem)
|
||||
|
||||
// Calculate memory growth safely to prevent underflow
|
||||
var memGrowthMB float64
|
||||
if finalMem.HeapAlloc >= initialMem.HeapAlloc {
|
||||
memGrowthMB = float64(finalMem.HeapAlloc-initialMem.HeapAlloc) / (1024 * 1024)
|
||||
} else {
|
||||
// Memory decreased (GC occurred), treat as 0 growth
|
||||
memGrowthMB = 0
|
||||
}
|
||||
|
||||
// Log memory statistics for debugging
|
||||
t.Logf("Initial memory: %.2f MB", float64(initialMem.HeapAlloc)/(1024*1024))
|
||||
t.Logf("Final memory: %.2f MB", float64(finalMem.HeapAlloc)/(1024*1024))
|
||||
t.Logf("Memory growth: %.2f MB", memGrowthMB)
|
||||
|
||||
// Check for excessive memory growth (threshold: 50MB)
|
||||
if memGrowthMB > 50 {
|
||||
t.Errorf("Excessive memory growth detected: %.2f MB", memGrowthMB)
|
||||
}
|
||||
|
||||
// Verify no lingering operations
|
||||
metrics := coordinator.GetMetrics()
|
||||
if inflight, ok := metrics["current_inflight"].(int32); ok {
|
||||
if inflight != 0 {
|
||||
t.Errorf("Expected 0 in-flight operations after completion, got %d", inflight)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cleanup is working
|
||||
coordinator.attemptsMutex.RLock()
|
||||
sessionCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
// Should have cleaned up old sessions (only recent ones remain)
|
||||
if sessionCount > numWorkers*2 {
|
||||
t.Errorf("Session cleanup not working properly, %d sessions remain", sessionCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshTimeoutHandling verifies that refresh operations timeout properly
|
||||
func TestRefreshTimeoutHandling(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.RefreshTimeout = 100 * time.Millisecond
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Mock refresh function that hangs
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
time.Sleep(1 * time.Second) // Much longer than timeout
|
||||
return &TokenResponse{AccessToken: "token"}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
|
||||
_, err := coordinator.CoordinateRefresh(ctx, "session", "refresh_token", refreshFunc)
|
||||
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Verify timeout occurred
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
|
||||
// Verify it timed out within reasonable bounds
|
||||
if elapsed > 200*time.Millisecond {
|
||||
t.Errorf("Timeout took too long: %v", elapsed)
|
||||
}
|
||||
|
||||
if err != nil && err.Error() != fmt.Sprintf("refresh operation timed out after %v", config.RefreshTimeout) {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentDifferentTokens verifies that refreshes for different tokens
|
||||
// proceed independently without blocking each other
|
||||
func TestConcurrentDifferentTokens(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
numTokens := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numTokens)
|
||||
|
||||
// Track execution order
|
||||
executionOrder := make([]int, 0, numTokens)
|
||||
var executionMutex sync.Mutex
|
||||
|
||||
for i := 0; i < numTokens; i++ {
|
||||
go func(tokenID int) {
|
||||
defer wg.Done()
|
||||
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
executionMutex.Lock()
|
||||
executionOrder = append(executionOrder, tokenID)
|
||||
executionMutex.Unlock()
|
||||
|
||||
// Varying processing times
|
||||
time.Sleep(time.Duration(tokenID*10) * time.Millisecond)
|
||||
|
||||
return &TokenResponse{
|
||||
AccessToken: fmt.Sprintf("token_%d", tokenID),
|
||||
RefreshToken: fmt.Sprintf("refresh_%d", tokenID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
fmt.Sprintf("session_%d", tokenID),
|
||||
fmt.Sprintf("refresh_token_%d", tokenID),
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Token %d refresh failed: %v", tokenID, err)
|
||||
}
|
||||
|
||||
if resp == nil || resp.AccessToken != fmt.Sprintf("token_%d", tokenID) {
|
||||
t.Errorf("Token %d got wrong response", tokenID)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all tokens were processed
|
||||
if len(executionOrder) != numTokens {
|
||||
t.Errorf("Expected %d executions, got %d", numTokens, len(executionOrder))
|
||||
}
|
||||
|
||||
// Verify no deduplication occurred (all different tokens)
|
||||
metrics := coordinator.GetMetrics()
|
||||
if deduped, ok := metrics["deduplicated_requests"].(int64); ok {
|
||||
if deduped != 0 {
|
||||
t.Errorf("No deduplication expected for different tokens, got %d", deduped)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMaxConcurrentRefreshes verifies that the coordinator respects
|
||||
// the maximum concurrent refresh limit
|
||||
func TestMaxConcurrentRefreshes(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxConcurrentRefreshes = 2
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Track concurrent executions
|
||||
var currentConcurrent int32
|
||||
var maxConcurrent int32
|
||||
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
current := atomic.AddInt32(¤tConcurrent, 1)
|
||||
|
||||
// Update max if needed
|
||||
for {
|
||||
max := atomic.LoadInt32(&maxConcurrent)
|
||||
if current <= max || atomic.CompareAndSwapInt32(&maxConcurrent, max, current) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
atomic.AddInt32(¤tConcurrent, -1)
|
||||
|
||||
return &TokenResponse{AccessToken: "token"}, nil
|
||||
}
|
||||
|
||||
numRequests := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
errors := make([]error, 0, numRequests)
|
||||
var errorMutex sync.Mutex
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
fmt.Sprintf("session_%d", id),
|
||||
fmt.Sprintf("token_%d", id),
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errorMutex.Lock()
|
||||
errors = append(errors, err)
|
||||
errorMutex.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Some requests should have been rejected due to concurrency limit
|
||||
if len(errors) == 0 {
|
||||
t.Error("Expected some requests to be rejected due to concurrency limit")
|
||||
}
|
||||
|
||||
// Verify max concurrent never exceeded limit
|
||||
if maxConcurrent > int32(config.MaxConcurrentRefreshes) {
|
||||
t.Errorf("Max concurrent refreshes (%d) exceeded limit (%d)",
|
||||
maxConcurrent, config.MaxConcurrentRefreshes)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionWindowReset verifies that refresh attempt windows reset properly
|
||||
func TestSessionWindowReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxRefreshAttempts = 2
|
||||
config.RefreshAttemptWindow = 500 * time.Millisecond
|
||||
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Set circuit breaker to not interfere with rate limiting test
|
||||
coordinator.circuitBreaker.config.MaxFailures = 10
|
||||
|
||||
// Use unique identifiers to prevent test interference
|
||||
sessionID := fmt.Sprintf("window_test_session_%d", time.Now().UnixNano())
|
||||
refreshToken := fmt.Sprintf("test_refresh_token_%d", time.Now().UnixNano())
|
||||
|
||||
// Mock refresh function that always fails
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("refresh failed")
|
||||
}
|
||||
|
||||
// Use up the attempts in the first window
|
||||
for i := 0; i < config.MaxRefreshAttempts; i++ {
|
||||
ctx := context.Background()
|
||||
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
}
|
||||
|
||||
// Next attempt should trigger cooldown
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Expected cooldown after max attempts")
|
||||
}
|
||||
|
||||
// Wait for window to expire (but not cooldown)
|
||||
time.Sleep(config.RefreshAttemptWindow + 100*time.Millisecond)
|
||||
|
||||
// Should still be in cooldown (cooldown > window)
|
||||
_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Should still be in cooldown period")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConcurrentRefreshDeduplication measures performance of deduplication
|
||||
func BenchmarkConcurrentRefreshDeduplication(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &TokenResponse{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
ctx := context.Background()
|
||||
sessionID := fmt.Sprintf("session_%d", i%10) // Reuse 10 sessions
|
||||
refreshToken := fmt.Sprintf("token_%d", i%10) // Reuse 10 tokens
|
||||
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
i++
|
||||
}
|
||||
})
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
// Report metrics
|
||||
metrics := coordinator.GetMetrics()
|
||||
b.Logf("Total requests: %v", metrics["total_requests"])
|
||||
b.Logf("Deduplicated: %v", metrics["deduplicated_requests"])
|
||||
}
|
||||
|
||||
// TestCleanupRoutine verifies that the cleanup routine removes stale entries
|
||||
func TestCleanupRoutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.CleanupInterval = 100 * time.Millisecond
|
||||
config.RefreshAttemptWindow = 200 * time.Millisecond
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Add some sessions
|
||||
for i := 0; i < 5; i++ {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
||||
}
|
||||
|
||||
// Verify sessions exist
|
||||
coordinator.attemptsMutex.RLock()
|
||||
initialCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if initialCount != 5 {
|
||||
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (2x window + cleanup interval)
|
||||
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
||||
|
||||
// Verify sessions were cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
finalCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestRefreshCoordinatorRaceCondition specifically tests for race conditions
|
||||
// in the refresh coordinator's concurrent operation handling
|
||||
func TestRefreshCoordinatorRaceCondition(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
// Increase rate limit for this race condition test
|
||||
config.MaxRefreshAttempts = 100 // Allow many attempts for race testing
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
// Test concurrent access to the same refresh token
|
||||
var executions int32
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
atomic.AddInt32(&executions, 1)
|
||||
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||
return &TokenResponse{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Launch many goroutines concurrently
|
||||
const numGoroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
ctx := context.Background()
|
||||
sessionID := "test_session"
|
||||
refreshToken := "test_refresh_token"
|
||||
|
||||
// Use a channel to ensure all goroutines start at the same time
|
||||
startChan := make(chan struct{})
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for signal to start
|
||||
<-startChan
|
||||
|
||||
// All goroutines try to refresh at the same time
|
||||
result, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
// Basic validation
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: unexpected error: %v", id, err)
|
||||
}
|
||||
if result == nil || result.AccessToken != "test_token" {
|
||||
t.Errorf("Goroutine %d: invalid result", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Release all goroutines at once
|
||||
close(startChan)
|
||||
|
||||
// Wait for completion
|
||||
wg.Wait()
|
||||
|
||||
// Check that deduplication worked
|
||||
actualExecutions := atomic.LoadInt32(&executions)
|
||||
t.Logf("Executions: %d out of %d goroutines", actualExecutions, numGoroutines)
|
||||
|
||||
// With proper deduplication, we should have very few executions
|
||||
// Allow for some timing slack - up to 3 executions is acceptable
|
||||
if actualExecutions > 3 {
|
||||
t.Errorf("Too many refresh executions: %d (expected <= 3)", actualExecutions)
|
||||
}
|
||||
|
||||
// Verify metrics
|
||||
metrics := coordinator.GetMetrics()
|
||||
if total, ok := metrics["total_requests"].(int64); ok {
|
||||
if total != int64(numGoroutines) {
|
||||
t.Errorf("Expected %d total requests, got %d", numGoroutines, total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshCoordinatorNoRaceWithDifferentTokens verifies no interference
|
||||
// between different refresh tokens
|
||||
func TestRefreshCoordinatorNoRaceWithDifferentTokens(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
// Increase concurrent limit to handle 10 different tokens
|
||||
config.MaxConcurrentRefreshes = 15
|
||||
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
||||
// Increase rate limit since we have 5 goroutines per token
|
||||
config.MaxRefreshAttempts = 10 // Allow multiple attempts per session
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
|
||||
const numTokens = 10
|
||||
const goroutinesPerToken = 5
|
||||
|
||||
var totalExecutions int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numTokens * goroutinesPerToken)
|
||||
|
||||
refreshFunc := func() (*TokenResponse, error) {
|
||||
atomic.AddInt32(&totalExecutions, 1)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
return &TokenResponse{
|
||||
AccessToken: "token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Launch goroutines for different tokens with unique identifiers
|
||||
baseID := time.Now().UnixNano()
|
||||
for tokenID := 0; tokenID < numTokens; tokenID++ {
|
||||
sessionID := fmt.Sprintf("session_%d_%d", baseID, tokenID)
|
||||
refreshToken := fmt.Sprintf("refresh_%d_%d", baseID, tokenID)
|
||||
|
||||
for i := 0; i < goroutinesPerToken; i++ {
|
||||
go func(tid, gid int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
refreshFunc,
|
||||
)
|
||||
|
||||
if err != nil && err.Error() != "maximum concurrent refresh operations reached" {
|
||||
// Only log non-concurrent-limit errors as failures
|
||||
t.Errorf("Token %d, Goroutine %d: unexpected error: %v", tid, gid, err)
|
||||
}
|
||||
}(tokenID, i)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Each token should have had ~1 execution (maybe 2 due to timing)
|
||||
actualExecutions := atomic.LoadInt32(&totalExecutions)
|
||||
t.Logf("Total executions: %d for %d different tokens", actualExecutions, numTokens)
|
||||
|
||||
// Should be close to numTokens (one per unique token)
|
||||
if actualExecutions > numTokens*2 {
|
||||
t.Errorf("Too many executions: %d (expected ~%d)", actualExecutions, numTokens)
|
||||
}
|
||||
}
|
||||
+5
-3
@@ -637,12 +637,13 @@ func (sm *SessionManager) EnhanceSessionSecurity(options *sessions.Options, r *h
|
||||
options.Secure = true
|
||||
}
|
||||
|
||||
if r.Header.Get("X-Requested-With") == "XMLHttpRequest" {
|
||||
options.SameSite = http.SameSiteStrictMode
|
||||
}
|
||||
// Keep SameSite=Lax consistently for OAuth flows
|
||||
// Removed dynamic switching based on XMLHttpRequest header to prevent redirect loop
|
||||
options.SameSite = http.SameSiteLaxMode
|
||||
}
|
||||
|
||||
options.HttpOnly = true
|
||||
options.Path = "/" // Ensure cookies are available on all paths for OAuth flow
|
||||
|
||||
if sm.cookieDomain != "" {
|
||||
options.Domain = sm.cookieDomain
|
||||
@@ -930,6 +931,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
sd.idTokenSession.Options = options
|
||||
|
||||
var firstErr error
|
||||
saveOrLogError := func(s *sessions.Session, name string) {
|
||||
|
||||
@@ -34,6 +34,11 @@ var (
|
||||
globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions
|
||||
)
|
||||
|
||||
// ResetGlobalSessionCounters resets global session tracking for testing
|
||||
func ResetGlobalSessionCounters() {
|
||||
atomic.StoreInt64(&globalSessionCount, 0)
|
||||
}
|
||||
|
||||
// Predefined configurations for each token type
|
||||
var (
|
||||
AccessTokenConfig = TokenConfig{
|
||||
|
||||
@@ -33,6 +33,11 @@ var (
|
||||
globalMaxSessions int64 = 5000 // CRITICAL FIX: Global limit of 5000 total sessions
|
||||
)
|
||||
|
||||
// ResetGlobalSessionCounters resets global session tracking for testing
|
||||
func ResetGlobalSessionCounters() {
|
||||
atomic.StoreInt64(&globalSessionCount, 0)
|
||||
}
|
||||
|
||||
// Predefined configurations for each token type
|
||||
var (
|
||||
AccessTokenConfig = TokenConfig{
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/session/chunking"
|
||||
)
|
||||
|
||||
// GlobalTestCleanup tracks and cleans up test resources
|
||||
@@ -113,6 +115,15 @@ func (g *GlobalTestCleanup) CleanupAll() {
|
||||
// Reset all global singletons to prevent state pollution between tests
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
ResetGlobalMemoryOptimizations()
|
||||
ResetSingletonNoOpLogger()
|
||||
|
||||
// Reset global session counters to prevent overflow in memory calculations
|
||||
ResetGlobalSessionCounters()
|
||||
|
||||
// Reset global session counters in chunking package as well
|
||||
// Note: This calls the function in session/chunking package
|
||||
resetChunkingGlobalSessionCounters()
|
||||
|
||||
// Give background tasks time to finish cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -949,3 +960,9 @@ func (h *PerformanceTestHelper) Reset() {
|
||||
defer h.mu.Unlock()
|
||||
h.samples = h.samples[:0]
|
||||
}
|
||||
|
||||
// resetChunkingGlobalSessionCounters resets the global session counters
|
||||
// in the chunking package to prevent test interference
|
||||
func resetChunkingGlobalSessionCounters() {
|
||||
chunking.ResetGlobalSessionCounters()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user