diff --git a/main.go b/main.go index 8445081..83a432b 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ type TraefikOidc struct { issuerURL string revocationURL string jwkCache JWKCacheInterface + metadataCache *MetadataCache tokenBlacklist *TokenBlacklist jwksURL string clientID string @@ -253,6 +254,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h }(), tokenBlacklist: NewTokenBlacklist(), jwkCache: &JWKCache{}, + metadataCache: NewMetadataCache(), clientID: config.ClientID, clientSecret: config.ClientSecret, forceHTTPS: config.ForceHTTPS, @@ -292,40 +294,58 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h func (t *TraefikOidc) initializeMetadata(providerURL string) { t.logger.Debug("Starting provider metadata discovery") - // Keep retrying until successful - backoff := time.Second - maxBackoff := 30 * time.Second + // Get metadata from cache or fetch it + metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) + if err != nil { + t.logger.Errorf("Failed to get provider metadata: %v", err) + return + } + + if metadata != nil { + t.logger.Debug("Successfully initialized provider metadata") + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL + + // Start metadata refresh goroutine + go t.startMetadataRefresh(providerURL) + + // Only close channel on success + close(t.initComplete) + return + } + + t.logger.Error("Received nil metadata") +} + +// startMetadataRefresh periodically refreshes the OIDC metadata +func (t *TraefikOidc) startMetadataRefresh(providerURL string) { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for { - metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger) - - if err != nil { - t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff) - time.Sleep(backoff) - - // Exponential backoff with max - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff + select { + case <-ticker.C: + t.logger.Debug("Refreshing OIDC metadata") + metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) + if err != nil { + t.logger.Errorf("Failed to refresh metadata: %v", err) + continue + } + + if metadata != nil { + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL + t.logger.Debug("Successfully refreshed metadata") } - continue } - - if metadata != nil { - t.logger.Debug("Successfully initialized provider metadata") - t.jwksURL = metadata.JWKSURL - t.authURL = metadata.AuthURL - t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer - t.revocationURL = metadata.RevokeURL - t.endSessionURL = metadata.EndSessionURL - - // Only close channel on success - close(t.initComplete) - return - } - - t.logger.Error("Received nil metadata, retrying") - time.Sleep(backoff) } } @@ -693,7 +713,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { // startTokenCleanup starts the token cleanup goroutine func (t *TraefikOidc) startTokenCleanup() { ctx, cancel := context.WithCancel(context.Background()) - ticker := time.NewTicker(30 * time.Second) // Increased frequency to prevent memory buildup + ticker := time.NewTicker(15 * time.Second) // More frequent cleanup go func() { defer ticker.Stop() @@ -706,14 +726,18 @@ func (t *TraefikOidc) startTokenCleanup() { case <-ticker.C: t.logger.Debug("Starting token cleanup cycle") - // Run cleanup in a separate goroutine with timeout - cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 10*time.Second) + // Run cleanup in a separate goroutine with shorter timeout + cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 5*time.Second) done := make(chan struct{}) go func() { defer close(done) + // Clean up in smaller batches to prevent long-running operations t.tokenCache.Cleanup() t.tokenBlacklist.Cleanup() + + // Force garbage collection after cleanup + runtime.GC() }() // Wait for cleanup to complete or timeout diff --git a/metadata_cache.go b/metadata_cache.go new file mode 100644 index 0000000..7b44164 --- /dev/null +++ b/metadata_cache.go @@ -0,0 +1,54 @@ +package traefikoidc + +import ( + "fmt" + "net/http" + "sync" + "time" +) + +// MetadataCache provides thread-safe caching for OIDC provider metadata +type MetadataCache struct { + metadata *ProviderMetadata + expiresAt time.Time + mutex sync.RWMutex +} + +// NewMetadataCache creates a new metadata cache instance +func NewMetadataCache() *MetadataCache { + return &MetadataCache{} +} + +// GetMetadata retrieves the metadata from cache or fetches it if expired +func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client, logger *Logger) (*ProviderMetadata, error) { + c.mutex.RLock() + if c.metadata != nil && time.Now().Before(c.expiresAt) { + defer c.mutex.RUnlock() + return c.metadata, nil + } + c.mutex.RUnlock() + + c.mutex.Lock() + defer c.mutex.Unlock() + + // Double-check after acquiring write lock + if c.metadata != nil && time.Now().Before(c.expiresAt) { + return c.metadata, nil + } + + metadata, err := discoverProviderMetadata(providerURL, httpClient, logger) + if err != nil { + if c.metadata != nil { + // On error, extend current cache by 5 minutes to prevent thundering herd + c.expiresAt = time.Now().Add(5 * time.Minute) + logger.Errorf("Failed to refresh metadata, using cached version for 5 more minutes: %v", err) + return c.metadata, nil + } + return nil, fmt.Errorf("failed to fetch provider metadata: %w", err) + } + + c.metadata = metadata + c.expiresAt = time.Now().Add(1 * time.Hour) + + return metadata, nil +}