diff --git a/autocleanup.go b/autocleanup.go index 82930d9..ecba081 100644 --- a/autocleanup.go +++ b/autocleanup.go @@ -266,18 +266,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error { max := atomic.LoadInt32(&cb.maxConcurrent) // For cleanup tasks, be more restrictive (singleton-like behavior) + // However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456) if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") { cb.tasksMu.RLock() - hasCleanupTask := false + hasSameTask := false for activeTask := range cb.activeTasks { - if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") { - hasCleanupTask = true + // Only block if the EXACT same task is already running + // This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently + if activeTask == taskName { + hasSameTask = true break } } cb.tasksMu.RUnlock() - if hasCleanupTask { + if hasSameTask { return fmt.Errorf("cleanup/singleton task already running: %s", taskName) } } diff --git a/main.go b/main.go index 66188aa..e49402a 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,8 @@ package traefikoidc import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "net/http" "os" @@ -392,11 +394,15 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662) t.registrationURL = metadata.RegistrationURL // OIDC Dynamic Client Registration endpoint (RFC 7591) + // Copy values for logging after unlock to avoid race conditions + introspectionURL := t.introspectionURL + registrationURL := t.registrationURL + t.metadataMu.Unlock() // Log introspection endpoint availability for opaque token support - if t.introspectionURL != "" { - t.logger.Debugf("Token introspection endpoint discovered: %s", t.introspectionURL) + if introspectionURL != "" { + t.logger.Debugf("Token introspection endpoint discovered: %s", introspectionURL) if t.allowOpaqueTokens { t.logger.Debugf("Opaque token support enabled with introspection endpoint") } @@ -405,8 +411,8 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { } // Log registration endpoint availability - if t.registrationURL != "" { - t.logger.Debugf("Dynamic client registration endpoint discovered: %s", t.registrationURL) + if registrationURL != "" { + t.logger.Debugf("Dynamic client registration endpoint discovered: %s", registrationURL) } // Perform Dynamic Client Registration if enabled and ClientID is not set @@ -474,7 +480,10 @@ func (t *TraefikOidc) performDynamicClientRegistration() { func (t *TraefikOidc) startMetadataRefresh(providerURL string) { // Use singleton resource manager for metadata refresh rm := GetResourceManager() - taskName := "singleton-metadata-refresh" + // Use last 6 chars of provider URL hash to create unique task name per realm + // This fixes multi-realm support where different Keycloak realms need separate refresh tasks + hash := sha256.Sum256([]byte(providerURL)) + taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6] // Create refresh function refreshFunc := func() { @@ -510,6 +519,27 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { } } +// attemptMetadataRecovery tries to fetch provider metadata when the system is in a failed state. +// This is called periodically (every 30s) when requests come in and metadata is unavailable. +// It allows automatic recovery when the OIDC provider becomes available again. +func (t *TraefikOidc) attemptMetadataRecovery() { + if t.metadataCache == nil || t.httpClient == nil { + return + } + + // Try to fetch metadata (single attempt, no aggressive retry here since this runs every 30s) + metadata, err := t.metadataCache.GetMetadata(t.providerURL, t.httpClient, t.logger) + if err != nil { + t.safeLogDebugf("Metadata recovery attempt failed: %v", err) + return + } + + if metadata != nil { + t.updateMetadataEndpoints(metadata) + t.safeLogInfo("Successfully recovered OIDC provider metadata - service restored") + } +} + // createCaseInsensitiveStringMap creates a map with lowercase keys for case-insensitive matching. // This is used for case-insensitive matching of email addresses. // Parameters: diff --git a/main_test.go b/main_test.go index 63f665d..13234b5 100644 --- a/main_test.go +++ b/main_test.go @@ -209,11 +209,8 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient * } func (m *MockJWKCache) Cleanup() { - // Mock cleanup implementation - m.mu.Lock() - defer m.mu.Unlock() - m.JWKS = nil - m.Err = nil + // Mock cleanup is a no-op - we don't want to destroy the mock JWKS data + // Real cleanup is for expired entries, not resetting all data } // MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls. @@ -2427,6 +2424,276 @@ func TestMultipleMiddlewareInstances(t *testing.T) { } } +// TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances +// with different provider URLs (e.g., different Keycloak realms) get separate +// metadata refresh tasks. This addresses the issue reported in PR #88. +func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode") + } + + // Create two mock provider metadata servers simulating different Keycloak realms + realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + w.WriteHeader(http.StatusNotFound) + return + } + metadata := ProviderMetadata{ + Issuer: "https://keycloak.example.com/realms/realm1", + AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth", + TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token", + JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs", + EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout", + } + json.NewEncoder(w).Encode(metadata) + })) + defer realm1Server.Close() + + realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + w.WriteHeader(http.StatusNotFound) + return + } + metadata := ProviderMetadata{ + Issuer: "https://keycloak.example.com/realms/realm2", + AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth", + TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token", + JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs", + EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout", + } + json.NewEncoder(w).Encode(metadata) + })) + defer realm2Server.Close() + + // Config for realm1 + config1 := &Config{ + ProviderURL: realm1Server.URL, + ClientID: "realm1-client", + ClientSecret: "realm1-secret", + CallbackURL: "/realm1/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + CookiePrefix: "_oidc_realm1_", + } + + // Config for realm2 + config2 := &Config{ + ProviderURL: realm2Server.URL, + ClientID: "realm2-client", + ClientSecret: "realm2-secret", + CallbackURL: "/realm2/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + CookiePrefix: "_oidc_realm2_", + } + + // Create middleware instances for both realms + middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config1, "realm1-middleware") + if err != nil { + t.Fatalf("Failed to create middleware for realm1: %v", err) + } + + middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config2, "realm2-middleware") + if err != nil { + t.Fatalf("Failed to create middleware for realm2: %v", err) + } + + m1, ok1 := middleware1.(*TraefikOidc) + m2, ok2 := middleware2.(*TraefikOidc) + if !ok1 || !ok2 { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + + // Clean up middleware instances + defer func() { + if err := m1.Close(); err != nil { + t.Errorf("Failed to close realm1 middleware: %v", err) + } + if err := m2.Close(); err != nil { + t.Errorf("Failed to close realm2 middleware: %v", err) + } + }() + + // Wait for both instances to initialize + select { + case <-m1.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Realm1 middleware failed to initialize") + } + + select { + case <-m2.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Realm2 middleware failed to initialize") + } + + // Verify each instance has the correct issuer URL from their respective realms + if !strings.Contains(m1.issuerURL, "realm1") { + t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL) + } + if !strings.Contains(m2.issuerURL, "realm2") { + t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL) + } + + // Verify provider URLs are different + if m1.providerURL == m2.providerURL { + t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL) + } + + // Test that each middleware can handle requests independently + req1 := httptest.NewRequest("GET", "/realm1/protected", nil) + rr1 := httptest.NewRecorder() + m1.ServeHTTP(rr1, req1) + + req2 := httptest.NewRequest("GET", "/realm2/protected", nil) + rr2 := httptest.NewRecorder() + m2.ServeHTTP(rr2, req2) + + // Both should redirect to their respective auth URLs + if rr1.Code != http.StatusFound { + t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code) + } + if rr2.Code != http.StatusFound { + t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code) + } + + location1 := rr1.Header().Get("Location") + location2 := rr2.Header().Get("Location") + + if !strings.Contains(location1, "realm1") { + t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1) + } + if !strings.Contains(location2, "realm2") { + t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2) + } +} + +// TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically +// recovers when the OIDC provider becomes available after initial failure. +func TestMetadataRecoveryOnProviderFailure(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode") + } + + // Track whether the provider is "available" + providerAvailable := false + var mu sync.Mutex + + // Create mock provider that initially fails, then becomes available + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + available := providerAvailable + mu.Unlock() + + if !available { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + if r.URL.Path == "/.well-known/openid-configuration" { + metadata := ProviderMetadata{ + Issuer: "https://test-issuer.com", + AuthURL: "https://test-issuer.com/auth", + TokenURL: "https://test-issuer.com/token", + JWKSURL: "https://test-issuer.com/jwks", + EndSessionURL: "https://test-issuer.com/logout", + } + json.NewEncoder(w).Encode(metadata) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockServer.Close() + + config := &Config{ + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + } + + // Create middleware while provider is unavailable + middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config, "test-recovery") + if err != nil { + t.Fatalf("Failed to create middleware: %v", err) + } + + m, ok := middleware.(*TraefikOidc) + if !ok { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + defer m.Close() + + // Wait for initial initialization to complete (it should fail) + select { + case <-m.initComplete: + case <-time.After(15 * time.Second): + t.Fatal("Initialization did not complete in time") + } + + // Verify initial state - should be in failed state (no issuerURL) + m.metadataMu.RLock() + initialIssuer := m.issuerURL + m.metadataMu.RUnlock() + + if initialIssuer != "" { + t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer) + } + + // First request should get 503 + req1 := httptest.NewRequest("GET", "/protected", nil) + rr1 := httptest.NewRecorder() + m.ServeHTTP(rr1, req1) + + if rr1.Code != http.StatusServiceUnavailable { + t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code) + } + + // Now make the provider available + mu.Lock() + providerAvailable = true + mu.Unlock() + + // Reset the retry timer to allow immediate retry + m.metadataRetryMutex.Lock() + m.lastMetadataRetryTime = time.Time{} // Reset to zero time + m.metadataRetryMutex.Unlock() + + // Second request should trigger recovery attempt + req2 := httptest.NewRequest("GET", "/protected", nil) + rr2 := httptest.NewRecorder() + m.ServeHTTP(rr2, req2) + + // Give the async recovery a moment to complete + time.Sleep(100 * time.Millisecond) + + // Check if recovery happened + m.metadataMu.RLock() + recoveredIssuer := m.issuerURL + m.metadataMu.RUnlock() + + if recoveredIssuer == "" { + t.Error("Expected issuerURL to be recovered after provider became available") + } + + // Third request should succeed (redirect to auth, not 503) + req3 := httptest.NewRequest("GET", "/protected", nil) + rr3 := httptest.NewRecorder() + m.ServeHTTP(rr3, req3) + + if rr3.Code == http.StatusServiceUnavailable { + t.Errorf("Expected redirect after recovery, still got 503") + } + + t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d", + initialIssuer, recoveredIssuer, rr3.Code) +} + func TestServeHTTPRolesAndGroups(t *testing.T) { ts := NewTestSuite(t) ts.Setup() diff --git a/middleware.go b/middleware.go index 52372b6..afb0595 100644 --- a/middleware.go +++ b/middleware.go @@ -50,6 +50,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.metadataMu.RUnlock() if issuerURL == "" { + // Provider metadata initialization failed - try to recover + // Retry every 30 seconds to allow automatic recovery when provider comes back online + t.metadataRetryMutex.Lock() + shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second + if shouldRetry { + t.lastMetadataRetryTime = time.Now() + } + t.metadataRetryMutex.Unlock() + + if shouldRetry && t.providerURL != "" { + t.logger.Info("Attempting to recover OIDC provider metadata...") + go t.attemptMetadataRecovery() + } + t.logger.Error("OIDC provider metadata initialization failed or incomplete") t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable) return diff --git a/singleton_resources_test.go b/singleton_resources_test.go index af63ace..634f6df 100644 --- a/singleton_resources_test.go +++ b/singleton_resources_test.go @@ -2,6 +2,8 @@ package traefikoidc import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "runtime" "sync" @@ -97,6 +99,89 @@ func TestSingletonResourceManager(t *testing.T) { } }) + t.Run("MultiRealmMetadataRefreshTaskNaming", func(t *testing.T) { + // This test verifies that different provider URLs generate different task names + // which is critical for multi-realm Keycloak support (PR #88) + + // Reset singletons for clean test state + resetResourceManagerForTesting() + ResetGlobalTaskRegistry() + defer ResetGlobalTaskRegistry() + rm := GetResourceManager() + + // Simulate different Keycloak realms + providerURL1 := "https://keycloak.example.com/realms/realm1" + providerURL2 := "https://keycloak.example.com/realms/realm2" + + // Generate task names using the same logic as startMetadataRefresh + hash1 := sha256.Sum256([]byte(providerURL1)) + taskName1 := "singleton-metadata-refresh-" + hex.EncodeToString(hash1[:])[0:6] + + hash2 := sha256.Sum256([]byte(providerURL2)) + taskName2 := "singleton-metadata-refresh-" + hex.EncodeToString(hash2[:])[0:6] + + // Verify task names are different + if taskName1 == taskName2 { + t.Errorf("Task names should be different for different provider URLs: %s vs %s", taskName1, taskName2) + } + + // Register both tasks + task1Called := int32(0) + task2Called := int32(0) + + err := rm.RegisterBackgroundTask(taskName1, 100*time.Millisecond, func() { + atomic.AddInt32(&task1Called, 1) + }) + if err != nil { + t.Errorf("Failed to register task 1: %v", err) + } + + err = rm.RegisterBackgroundTask(taskName2, 100*time.Millisecond, func() { + atomic.AddInt32(&task2Called, 1) + }) + if err != nil { + t.Errorf("Failed to register task 2: %v", err) + } + + // Start both tasks + _ = rm.StartBackgroundTask(taskName1) + _ = rm.StartBackgroundTask(taskName2) + + // Wait for tasks to execute + time.Sleep(250 * time.Millisecond) + + // Verify both tasks are running independently + if !rm.IsTaskRunning(taskName1) { + t.Error("Task 1 should be running") + } + if !rm.IsTaskRunning(taskName2) { + t.Error("Task 2 should be running") + } + + // Verify both tasks were called (at least once) + if atomic.LoadInt32(&task1Called) == 0 { + t.Error("Task 1 should have been called at least once") + } + if atomic.LoadInt32(&task2Called) == 0 { + t.Error("Task 2 should have been called at least once") + } + + // Stop both tasks + _ = rm.StopBackgroundTask(taskName1) + _ = rm.StopBackgroundTask(taskName2) + + // Verify tasks are stopped + time.Sleep(50 * time.Millisecond) + if rm.IsTaskRunning(taskName1) { + t.Error("Task 1 should be stopped") + } + if rm.IsTaskRunning(taskName2) { + t.Error("Task 2 should be stopped") + } + + t.Logf("Successfully verified multi-realm task isolation: task1=%s, task2=%s", taskName1, taskName2) + }) + t.Run("ReferenceCountingCleanup", func(t *testing.T) { rm := GetResourceManager() diff --git a/types.go b/types.go index 9f8be29..607616f 100644 --- a/types.go +++ b/types.go @@ -128,8 +128,10 @@ type TraefikOidc struct { suppressDiagnosticLogs bool firstRequestReceived bool metadataRefreshStarted bool - allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) - minimalHeaders bool // Reduce headers to prevent 431 errors + lastMetadataRetryTime time.Time // Track last metadata retry for failed state recovery + metadataRetryMutex sync.Mutex // Protects lastMetadataRetryTime + allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) + minimalHeaders bool // Reduce headers to prevent 431 errors securityHeadersApplier func(http.ResponseWriter, *http.Request) scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering scopesSupported []string // NEW - from provider metadata diff --git a/utilities.go b/utilities.go index cd7c21a..0df262b 100644 --- a/utilities.go +++ b/utilities.go @@ -3,6 +3,8 @@ package traefikoidc import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "html" @@ -222,8 +224,13 @@ func (t *TraefikOidc) Close() error { rm := GetResourceManager() // Stop singleton tasks related to this instance - _ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup - _ = rm.StopBackgroundTask("singleton-metadata-refresh") // Safe to ignore: best effort cleanup + _ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup + // Stop metadata refresh task using same hash-based name as startMetadataRefresh + if t.providerURL != "" { + hash := sha256.Sum256([]byte(t.providerURL)) + taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6] + _ = rm.StopBackgroundTask(taskName) // Safe to ignore: best effort cleanup + } // Remove reference for this instance rm.RemoveReference(t.name)