From d0b920c4f07e538aad3fc465d70b8bd0ba1e04b6 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Wed, 10 Dec 2025 13:07:22 +0000 Subject: [PATCH] multiple realms fix (#102) * Allow using multiple realms This change is a resurrection of PR #88, which can't be merged due to a significant refactor of the codebase. * Fix the autocleanup routine to handle multiple realms correctly, update tests. * Metadata rediscovery when the provider is unavailable for any reason during startup. This prevents a permanent 503 from the plugin when the OIDC provider was for some reason unavailable during startup. --- autocleanup.go | 11 +- main.go | 40 +++++- main_test.go | 277 +++++++++++++++++++++++++++++++++++- middleware.go | 14 ++ singleton_resources_test.go | 85 +++++++++++ types.go | 6 +- utilities.go | 11 +- 7 files changed, 426 insertions(+), 18 deletions(-) diff --git a/autocleanup.go b/autocleanup.go index 82930d9..ecba081 100644 --- a/autocleanup.go +++ b/autocleanup.go @@ -266,18 +266,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error { max := atomic.LoadInt32(&cb.maxConcurrent) // For cleanup tasks, be more restrictive (singleton-like behavior) + // However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456) if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") { cb.tasksMu.RLock() - hasCleanupTask := false + hasSameTask := false for activeTask := range cb.activeTasks { - if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") { - hasCleanupTask = true + // Only block if the EXACT same task is already running + // This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently + if activeTask == taskName { + hasSameTask = true break } } cb.tasksMu.RUnlock() - if hasCleanupTask { + if hasSameTask { return fmt.Errorf("cleanup/singleton task already running: %s", taskName) } } diff --git a/main.go
b/main.go index 66188aa..e49402a 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,8 @@ package traefikoidc import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "net/http" "os" @@ -392,11 +394,15 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662) t.registrationURL = metadata.RegistrationURL // OIDC Dynamic Client Registration endpoint (RFC 7591) + // Copy values for logging after unlock to avoid race conditions + introspectionURL := t.introspectionURL + registrationURL := t.registrationURL + t.metadataMu.Unlock() // Log introspection endpoint availability for opaque token support - if t.introspectionURL != "" { - t.logger.Debugf("Token introspection endpoint discovered: %s", t.introspectionURL) + if introspectionURL != "" { + t.logger.Debugf("Token introspection endpoint discovered: %s", introspectionURL) if t.allowOpaqueTokens { t.logger.Debugf("Opaque token support enabled with introspection endpoint") } @@ -405,8 +411,8 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { } // Log registration endpoint availability - if t.registrationURL != "" { - t.logger.Debugf("Dynamic client registration endpoint discovered: %s", t.registrationURL) + if registrationURL != "" { + t.logger.Debugf("Dynamic client registration endpoint discovered: %s", registrationURL) } // Perform Dynamic Client Registration if enabled and ClientID is not set @@ -474,7 +480,10 @@ func (t *TraefikOidc) performDynamicClientRegistration() { func (t *TraefikOidc) startMetadataRefresh(providerURL string) { // Use singleton resource manager for metadata refresh rm := GetResourceManager() - taskName := "singleton-metadata-refresh" + // Use the first 6 hex chars of the provider URL's SHA-256 hash to create a unique task name per realm + // This fixes multi-realm support where different Keycloak realms need separate refresh tasks + hash := sha256.Sum256([]byte(providerURL))
+ taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6] // Create refresh function refreshFunc := func() { @@ -510,6 +519,27 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { } } +// attemptMetadataRecovery tries to fetch provider metadata when the system is in a failed state. +// This is called periodically (every 30s) when requests come in and metadata is unavailable. +// It allows automatic recovery when the OIDC provider becomes available again. +func (t *TraefikOidc) attemptMetadataRecovery() { + if t.metadataCache == nil || t.httpClient == nil { + return + } + + // Try to fetch metadata (single attempt, no aggressive retry here since this runs every 30s) + metadata, err := t.metadataCache.GetMetadata(t.providerURL, t.httpClient, t.logger) + if err != nil { + t.safeLogDebugf("Metadata recovery attempt failed: %v", err) + return + } + + if metadata != nil { + t.updateMetadataEndpoints(metadata) + t.safeLogInfo("Successfully recovered OIDC provider metadata - service restored") + } +} + // createCaseInsensitiveStringMap creates a map with lowercase keys for case-insensitive matching. // This is used for case-insensitive matching of email addresses. // Parameters: diff --git a/main_test.go b/main_test.go index 63f665d..13234b5 100644 --- a/main_test.go +++ b/main_test.go @@ -209,11 +209,8 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient * } func (m *MockJWKCache) Cleanup() { - // Mock cleanup implementation - m.mu.Lock() - defer m.mu.Unlock() - m.JWKS = nil - m.Err = nil + // Mock cleanup is a no-op - we don't want to destroy the mock JWKS data + // Real cleanup is for expired entries, not resetting all data } // MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls. 
@@ -2427,6 +2424,276 @@ func TestMultipleMiddlewareInstances(t *testing.T) { } } +// TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances +// with different provider URLs (e.g., different Keycloak realms) get separate +// metadata refresh tasks. This addresses the issue reported in PR #88. +func TestMultiRealmMetadataRefreshIsolation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode") + } + + // Create two mock provider metadata servers simulating different Keycloak realms + realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + w.WriteHeader(http.StatusNotFound) + return + } + metadata := ProviderMetadata{ + Issuer: "https://keycloak.example.com/realms/realm1", + AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth", + TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token", + JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs", + EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout", + } + json.NewEncoder(w).Encode(metadata) + })) + defer realm1Server.Close() + + realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + w.WriteHeader(http.StatusNotFound) + return + } + metadata := ProviderMetadata{ + Issuer: "https://keycloak.example.com/realms/realm2", + AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth", + TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token", + JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs", + EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout", + } + json.NewEncoder(w).Encode(metadata) + })) + defer realm2Server.Close() + + // 
Config for realm1 + config1 := &Config{ + ProviderURL: realm1Server.URL, + ClientID: "realm1-client", + ClientSecret: "realm1-secret", + CallbackURL: "/realm1/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + CookiePrefix: "_oidc_realm1_", + } + + // Config for realm2 + config2 := &Config{ + ProviderURL: realm2Server.URL, + ClientID: "realm2-client", + ClientSecret: "realm2-secret", + CallbackURL: "/realm2/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + CookiePrefix: "_oidc_realm2_", + } + + // Create middleware instances for both realms + middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config1, "realm1-middleware") + if err != nil { + t.Fatalf("Failed to create middleware for realm1: %v", err) + } + + middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config2, "realm2-middleware") + if err != nil { + t.Fatalf("Failed to create middleware for realm2: %v", err) + } + + m1, ok1 := middleware1.(*TraefikOidc) + m2, ok2 := middleware2.(*TraefikOidc) + if !ok1 || !ok2 { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + + // Clean up middleware instances + defer func() { + if err := m1.Close(); err != nil { + t.Errorf("Failed to close realm1 middleware: %v", err) + } + if err := m2.Close(); err != nil { + t.Errorf("Failed to close realm2 middleware: %v", err) + } + }() + + // Wait for both instances to initialize + select { + case <-m1.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Realm1 middleware failed to initialize") + } + + select { + case <-m2.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Realm2 middleware failed to initialize") + } + + // Verify each instance has the correct issuer URL from their respective realms + if !strings.Contains(m1.issuerURL, "realm1") { + 
t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL) + } + if !strings.Contains(m2.issuerURL, "realm2") { + t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL) + } + + // Verify provider URLs are different + if m1.providerURL == m2.providerURL { + t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL) + } + + // Test that each middleware can handle requests independently + req1 := httptest.NewRequest("GET", "/realm1/protected", nil) + rr1 := httptest.NewRecorder() + m1.ServeHTTP(rr1, req1) + + req2 := httptest.NewRequest("GET", "/realm2/protected", nil) + rr2 := httptest.NewRecorder() + m2.ServeHTTP(rr2, req2) + + // Both should redirect to their respective auth URLs + if rr1.Code != http.StatusFound { + t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code) + } + if rr2.Code != http.StatusFound { + t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code) + } + + location1 := rr1.Header().Get("Location") + location2 := rr2.Header().Get("Location") + + if !strings.Contains(location1, "realm1") { + t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1) + } + if !strings.Contains(location2, "realm2") { + t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2) + } +} + +// TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically +// recovers when the OIDC provider becomes available after initial failure. 
+func TestMetadataRecoveryOnProviderFailure(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode") + } + + // Track whether the provider is "available" + providerAvailable := false + var mu sync.Mutex + + // Create mock provider that initially fails, then becomes available + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + available := providerAvailable + mu.Unlock() + + if !available { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + if r.URL.Path == "/.well-known/openid-configuration" { + metadata := ProviderMetadata{ + Issuer: "https://test-issuer.com", + AuthURL: "https://test-issuer.com/auth", + TokenURL: "https://test-issuer.com/token", + JWKSURL: "https://test-issuer.com/jwks", + EndSessionURL: "https://test-issuer.com/logout", + } + json.NewEncoder(w).Encode(metadata) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockServer.Close() + + config := &Config{ + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + } + + // Create middleware while provider is unavailable + middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config, "test-recovery") + if err != nil { + t.Fatalf("Failed to create middleware: %v", err) + } + + m, ok := middleware.(*TraefikOidc) + if !ok { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + defer m.Close() + + // Wait for initial initialization to complete (it should fail) + select { + case <-m.initComplete: + case <-time.After(15 * time.Second): + t.Fatal("Initialization did not complete in time") + } + + // Verify initial state - should be in failed state (no issuerURL) + m.metadataMu.RLock() + initialIssuer := m.issuerURL + m.metadataMu.RUnlock() + + if initialIssuer != "" { + 
t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer) + } + + // First request should get 503 + req1 := httptest.NewRequest("GET", "/protected", nil) + rr1 := httptest.NewRecorder() + m.ServeHTTP(rr1, req1) + + if rr1.Code != http.StatusServiceUnavailable { + t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code) + } + + // Now make the provider available + mu.Lock() + providerAvailable = true + mu.Unlock() + + // Reset the retry timer to allow immediate retry + m.metadataRetryMutex.Lock() + m.lastMetadataRetryTime = time.Time{} // Reset to zero time + m.metadataRetryMutex.Unlock() + + // Second request should trigger recovery attempt + req2 := httptest.NewRequest("GET", "/protected", nil) + rr2 := httptest.NewRecorder() + m.ServeHTTP(rr2, req2) + + // Give the async recovery a moment to complete + time.Sleep(100 * time.Millisecond) + + // Check if recovery happened + m.metadataMu.RLock() + recoveredIssuer := m.issuerURL + m.metadataMu.RUnlock() + + if recoveredIssuer == "" { + t.Error("Expected issuerURL to be recovered after provider became available") + } + + // Third request should succeed (redirect to auth, not 503) + req3 := httptest.NewRequest("GET", "/protected", nil) + rr3 := httptest.NewRecorder() + m.ServeHTTP(rr3, req3) + + if rr3.Code == http.StatusServiceUnavailable { + t.Errorf("Expected redirect after recovery, still got 503") + } + + t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d", + initialIssuer, recoveredIssuer, rr3.Code) +} + func TestServeHTTPRolesAndGroups(t *testing.T) { ts := NewTestSuite(t) ts.Setup() diff --git a/middleware.go b/middleware.go index 52372b6..afb0595 100644 --- a/middleware.go +++ b/middleware.go @@ -50,6 +50,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.metadataMu.RUnlock() if issuerURL == "" { + // Provider metadata initialization failed - try to recover + // Retry every 30 seconds to allow automatic recovery when 
provider comes back online + t.metadataRetryMutex.Lock() + shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second + if shouldRetry { + t.lastMetadataRetryTime = time.Now() + } + t.metadataRetryMutex.Unlock() + + if shouldRetry && t.providerURL != "" { + t.logger.Info("Attempting to recover OIDC provider metadata...") + go t.attemptMetadataRecovery() + } + t.logger.Error("OIDC provider metadata initialization failed or incomplete") t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable) return diff --git a/singleton_resources_test.go b/singleton_resources_test.go index af63ace..634f6df 100644 --- a/singleton_resources_test.go +++ b/singleton_resources_test.go @@ -2,6 +2,8 @@ package traefikoidc import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "runtime" "sync" @@ -97,6 +99,89 @@ func TestSingletonResourceManager(t *testing.T) { } }) + t.Run("MultiRealmMetadataRefreshTaskNaming", func(t *testing.T) { + // This test verifies that different provider URLs generate different task names + // which is critical for multi-realm Keycloak support (PR #88) + + // Reset singletons for clean test state + resetResourceManagerForTesting() + ResetGlobalTaskRegistry() + defer ResetGlobalTaskRegistry() + rm := GetResourceManager() + + // Simulate different Keycloak realms + providerURL1 := "https://keycloak.example.com/realms/realm1" + providerURL2 := "https://keycloak.example.com/realms/realm2" + + // Generate task names using the same logic as startMetadataRefresh + hash1 := sha256.Sum256([]byte(providerURL1)) + taskName1 := "singleton-metadata-refresh-" + hex.EncodeToString(hash1[:])[0:6] + + hash2 := sha256.Sum256([]byte(providerURL2)) + taskName2 := "singleton-metadata-refresh-" + hex.EncodeToString(hash2[:])[0:6] + + // Verify task names are different + if taskName1 == taskName2 { + t.Errorf("Task names should be different for different provider 
URLs: %s vs %s", taskName1, taskName2) + } + + // Register both tasks + task1Called := int32(0) + task2Called := int32(0) + + err := rm.RegisterBackgroundTask(taskName1, 100*time.Millisecond, func() { + atomic.AddInt32(&task1Called, 1) + }) + if err != nil { + t.Errorf("Failed to register task 1: %v", err) + } + + err = rm.RegisterBackgroundTask(taskName2, 100*time.Millisecond, func() { + atomic.AddInt32(&task2Called, 1) + }) + if err != nil { + t.Errorf("Failed to register task 2: %v", err) + } + + // Start both tasks + _ = rm.StartBackgroundTask(taskName1) + _ = rm.StartBackgroundTask(taskName2) + + // Wait for tasks to execute + time.Sleep(250 * time.Millisecond) + + // Verify both tasks are running independently + if !rm.IsTaskRunning(taskName1) { + t.Error("Task 1 should be running") + } + if !rm.IsTaskRunning(taskName2) { + t.Error("Task 2 should be running") + } + + // Verify both tasks were called (at least once) + if atomic.LoadInt32(&task1Called) == 0 { + t.Error("Task 1 should have been called at least once") + } + if atomic.LoadInt32(&task2Called) == 0 { + t.Error("Task 2 should have been called at least once") + } + + // Stop both tasks + _ = rm.StopBackgroundTask(taskName1) + _ = rm.StopBackgroundTask(taskName2) + + // Verify tasks are stopped + time.Sleep(50 * time.Millisecond) + if rm.IsTaskRunning(taskName1) { + t.Error("Task 1 should be stopped") + } + if rm.IsTaskRunning(taskName2) { + t.Error("Task 2 should be stopped") + } + + t.Logf("Successfully verified multi-realm task isolation: task1=%s, task2=%s", taskName1, taskName2) + }) + t.Run("ReferenceCountingCleanup", func(t *testing.T) { rm := GetResourceManager() diff --git a/types.go b/types.go index 9f8be29..607616f 100644 --- a/types.go +++ b/types.go @@ -128,8 +128,10 @@ type TraefikOidc struct { suppressDiagnosticLogs bool firstRequestReceived bool metadataRefreshStarted bool - allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) - minimalHeaders bool 
// Reduce headers to prevent 431 errors + lastMetadataRetryTime time.Time // Track last metadata retry for failed state recovery + metadataRetryMutex sync.Mutex // Protects lastMetadataRetryTime + allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks) + minimalHeaders bool // Reduce headers to prevent 431 errors securityHeadersApplier func(http.ResponseWriter, *http.Request) scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering scopesSupported []string // NEW - from provider metadata diff --git a/utilities.go b/utilities.go index cd7c21a..0df262b 100644 --- a/utilities.go +++ b/utilities.go @@ -3,6 +3,8 @@ package traefikoidc import ( + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "html" @@ -222,8 +224,13 @@ func (t *TraefikOidc) Close() error { rm := GetResourceManager() // Stop singleton tasks related to this instance - _ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup - _ = rm.StopBackgroundTask("singleton-metadata-refresh") // Safe to ignore: best effort cleanup + _ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup + // Stop metadata refresh task using same hash-based name as startMetadataRefresh + if t.providerURL != "" { + hash := sha256.Sum256([]byte(t.providerURL)) + taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6] + _ = rm.StopBackgroundTask(taskName) // Safe to ignore: best effort cleanup + } // Remove reference for this instance rm.RemoveReference(t.name)