diff --git a/main.go b/main.go index 2c54af1..aeb6185 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/google/uuid" @@ -83,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{ var newTicker = time.NewTicker -var ( - globalMetadataCache struct { - sync.Once - metadata *ProviderMetadata - err error - } -) - // VerifyToken verifies the provided JWT token func (t *TraefikOidc) VerifyToken(token string) error { t.logger.Debugf("Verifying token") @@ -266,23 +257,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h // initializeMetadata discovers and initializes the provider metadata func (t *TraefikOidc) initializeMetadata(providerURL string) { - globalMetadataCache.Once.Do(func() { - t.logger.Debug("Starting global provider metadata discovery") - metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) - globalMetadataCache.metadata = metadata - globalMetadataCache.err = err - }) - - if globalMetadataCache.err != nil { - t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err) - } else if globalMetadataCache.metadata != nil { - t.logger.Debug("Using cached provider metadata") - t.jwksURL = globalMetadataCache.metadata.JWKSURL - t.authURL = globalMetadataCache.metadata.AuthURL - t.tokenURL = globalMetadataCache.metadata.TokenURL - t.issuerURL = globalMetadataCache.metadata.Issuer - t.revocationURL = globalMetadataCache.metadata.RevokeURL - t.endSessionURL = globalMetadataCache.metadata.EndSessionURL + t.logger.Debug("Starting provider metadata discovery") + metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) + + if err != nil { + t.logger.Errorf("Failed to discover provider metadata: %v", err) + } else if metadata != nil { + t.logger.Debug("Using provider metadata") + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL } close(t.initComplete) diff --git a/main_test.go b/main_test.go index a151b9b..ea445f2 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package traefikoidc import ( + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" @@ -1342,6 +1343,99 @@ func TestExtractGroupsAndRoles(t *testing.T) { } } +// TestMultipleMiddlewareInstances verifies that multiple middleware instances +// can be created and initialized properly for different routes +func TestMultipleMiddlewareInstances(t *testing.T) { + // Create mock provider metadata server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + metadata := ProviderMetadata{ + Issuer: "https://test-issuer.com", + AuthURL: "https://test-issuer.com/auth", + TokenURL: "https://test-issuer.com/token", + JWKSURL: "https://test-issuer.com/jwks", + RevokeURL: "https://test-issuer.com/revoke", + EndSessionURL: "https://test-issuer.com/end-session", + } + json.NewEncoder(w).Encode(metadata) + })) + defer mockServer.Close() + + // Create base config + config := &Config{ + ProviderURL: mockServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackURL: "/callback", + SessionEncryptionKey: "test-encryption-key-thats-long-enough", + } + + // Create multiple middleware instances + routes := []string{"/api/v1", "/api/v2", "/api/v3"} + var middlewares []*TraefikOidc + + for _, route := range routes { + config.CallbackURL = route + "/callback" + middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), config, "test") + + if err != nil { + t.Fatalf("Failed to create middleware for route %s: %v", route, err) + } + + // Type assert to access internal fields + if m, ok := middleware.(*TraefikOidc); ok { + middlewares = append(middlewares, m) + } else { + t.Fatalf("Middleware is not of type *TraefikOidc") + } + } + + // Wait for all instances to initialize + for i, m := range middlewares { + select { + case <-m.initComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Middleware instance %d failed to initialize", i) + } + + // Verify each instance has its own unique configuration + if m.issuerURL != "https://test-issuer.com" { + t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL) + } + if m.authURL != "https://test-issuer.com/auth" { + t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL) + } + if m.tokenURL != "https://test-issuer.com/token" { + t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL) + } + if m.jwksURL != "https://test-issuer.com/jwks" { + t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL) + } + if m.redirURLPath != routes[i]+"/callback" { + t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath) + } + } + + // Test that each instance can handle requests independently + for i, m := range middlewares { + req := httptest.NewRequest("GET", routes[i]+"/protected", nil) + rr := httptest.NewRecorder() + + m.ServeHTTP(rr, req) + + // Should redirect to auth URL since not authenticated + if rr.Code != http.StatusFound { + t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code) + } + + location := rr.Header().Get("Location") + if !strings.Contains(location, "https://test-issuer.com/auth") { + t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location) + } + } +} + func TestServeHTTPRolesAndGroups(t *testing.T) { ts := &TestSuite{t: t} ts.Setup()