From 147aa0b16954dc40b18e75f984672e5ebc08c74c Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Mon, 6 Jan 2025 11:23:12 +0000 Subject: [PATCH] Fix the issue #16 Removed global metadata cache and sync.Once Each middleware instance now handles its own metadata initialization Added tests to verify multiple instances work correctly The changes ensure that: Each route gets its own properly initialized middleware instance Metadata is fetched and set correctly for each instance No shared state between instances that could cause conflicts Each instance can handle requests independently The added test verifies this by creating multiple middleware instances with different routes and confirming they all initialize and function correctly. The test specifically checks that: Each instance initializes successfully Each instance gets its own metadata configuration Each instance can handle requests independently Callback URLs are correctly set per route --- main.go | 39 ++++++++-------------- main_test.go | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 26 deletions(-) diff --git a/main.go b/main.go index 2c54af1..aeb6185 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/google/uuid" @@ -83,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{ var newTicker = time.NewTicker -var ( - globalMetadataCache struct { - sync.Once - metadata *ProviderMetadata - err error - } -) - // VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { t.logger.Debugf("Verifying token") @@ -266,23 +257,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // initializeMetadata discovers and initializes the provider metadata func (t *TraefikOidc) initializeMetadata(providerURL string) { - globalMetadataCache.Once.Do(func() { - t.logger.Debug("Starting global provider metadata discovery") - metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) - globalMetadataCache.metadata = metadata - globalMetadataCache.err = err - }) - - if globalMetadataCache.err != nil { - t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err) - } else if globalMetadataCache.metadata != nil { - t.logger.Debug("Using cached provider metadata") - t.jwksURL = globalMetadataCache.metadata.JWKSURL - t.authURL = globalMetadataCache.metadata.AuthURL - t.tokenURL = globalMetadataCache.metadata.TokenURL - t.issuerURL = globalMetadataCache.metadata.Issuer - t.revocationURL = globalMetadataCache.metadata.RevokeURL - t.endSessionURL = globalMetadataCache.metadata.EndSessionURL + t.logger.Debug("Starting provider metadata discovery") + metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) + + if err != nil { + t.logger.Errorf("Failed to discover provider metadata: %v", err) + } else if metadata != nil { + t.logger.Debug("Using provider metadata") + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL } close(t.initComplete) diff --git a/main_test.go b/main_test.go index a151b9b..ea445f2 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package traefikoidc import ( + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" @@ -1342,6 +1343,99 @@ func TestExtractGroupsAndRoles(t *testing.T) { } } +// TestMultipleMiddlewareInstances verifies that multiple middleware instances +// can be created and initialized properly for different routes +func TestMultipleMiddlewareInstances(t *testing.T) { + // Create mock provider metadata server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metadata := ProviderMetadata{ + Issuer: "https://test-issuer.com", + AuthURL: "https://test-issuer.com/auth", + TokenURL: "https://test-issuer.com/token", + JWKSURL: "https://test-issuer.com/jwks", + RevokeURL: "https://test-issuer.com/revoke", + EndSessionURL: "https://test-issuer.com/end-session", + } + json.NewEncoder(w).Encode(metadata) + })) + defer mockServer.Close() + + // Create base config + config := &Config{ + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + } + + // Create multiple middleware instances + routes := []string{"/api/v1", "/api/v2", "/api/v3"} + var middlewares []*TraefikOidc + + for _, route := range routes { + config.CallbackURL = route + "/callback" + middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config, "test") + + if err != nil { + t.Fatalf("Failed to create middleware for route %s: %v", route, err) + } + + // Type assert to access internal fields + if m, ok := middleware.(*TraefikOidc); ok { + middlewares = append(middlewares, m) + } else { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + } + + // Wait for all instances to initialize + for i, m := range middlewares { + select { + case <-m.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Middleware instance %d failed to initialize", i) + } + + // Verify each instance has its own unique configuration + if m.issuerURL != "https://test-issuer.com" { + t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL) + } + if m.authURL != "https://test-issuer.com/auth" { + t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL) + } + if m.tokenURL != "https://test-issuer.com/token" { + t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL) + } + if m.jwksURL != "https://test-issuer.com/jwks" { + t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL) + } + if m.redirURLPath != routes[i]+"/callback" { + t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath) + } + } + + // Test that each instance can handle requests independently + for i, m := range middlewares { + req := httptest.NewRequest("GET", routes[i]+"/protected", nil) + rr := httptest.NewRecorder() + + m.ServeHTTP(rr, req) + + // Should redirect to auth URL since not authenticated + if rr.Code != http.StatusFound { + t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code) + } + + location := rr.Header().Get("Location") + if !strings.Contains(location, "https://test-issuer.com/auth") { + t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location) + } + } +} + func TestServeHTTPRolesAndGroups(t *testing.T) { ts := &TestSuite{t: t} ts.Setup()