diff --git a/blacklist.go b/blacklist.go new file mode 100644 index 0000000..27a75c8 --- /dev/null +++ b/blacklist.go @@ -0,0 +1,110 @@ +package traefikoidc + +import ( + "sync" + "time" +) + +// TokenBlacklist manages a thread-safe list of revoked tokens with expiration. +type TokenBlacklist struct { + tokens map[string]time.Time + mutex sync.RWMutex +} + +// NewTokenBlacklist creates a new token blacklist instance. +func NewTokenBlacklist() *TokenBlacklist { + return &TokenBlacklist{ + tokens: make(map[string]time.Time), + } +} + +// Add adds a token to the blacklist with an expiration time. +func (b *TokenBlacklist) Add(token string, expiry time.Time) { + b.mutex.Lock() + defer b.mutex.Unlock() + + // Clean up expired tokens if we're at capacity + if len(b.tokens) >= 1000 { + now := time.Now() + futureThreshold := now.Add(time.Minute) + for t, exp := range b.tokens { + if now.After(exp) || futureThreshold.After(exp) { + delete(b.tokens, t) + } + } + + // If still at capacity, remove oldest token + if len(b.tokens) >= 1000 { + var oldestToken string + var oldestTime time.Time + first := true + for t, exp := range b.tokens { + if first || exp.Before(oldestTime) { + oldestToken = t + oldestTime = exp + first = false + } + } + if oldestToken != "" { + delete(b.tokens, oldestToken) + } + } + } + + b.tokens[token] = expiry +} + +// IsBlacklisted checks if a token is in the blacklist and not expired. +func (b *TokenBlacklist) IsBlacklisted(token string) bool { + b.mutex.RLock() + defer b.mutex.RUnlock() + + expiry, exists := b.tokens[token] + if !exists { + return false + } + + // If token is expired, remove it and return false + if time.Now().After(expiry) { + // Switch to write lock to remove expired token + b.mutex.RUnlock() + b.mutex.Lock() + delete(b.tokens, token) + b.mutex.Unlock() + b.mutex.RLock() + return false + } + + return true +} + +// Cleanup removes expired tokens from the blacklist. +// Also removes tokens that will expire within the next minute to prevent edge cases. +func (b *TokenBlacklist) Cleanup() { + b.mutex.Lock() + defer b.mutex.Unlock() + + now := time.Now() + futureThreshold := now.Add(time.Minute) + + for token, expiry := range b.tokens { + // Remove tokens that are expired or will expire soon + if now.After(expiry) || futureThreshold.After(expiry) { + delete(b.tokens, token) + } + } +} + +// Remove removes a token from the blacklist regardless of its expiration. +func (b *TokenBlacklist) Remove(token string) { + b.mutex.Lock() + defer b.mutex.Unlock() + delete(b.tokens, token) +} + +// Count returns the current number of tokens in the blacklist. +func (b *TokenBlacklist) Count() int { + b.mutex.RLock() + defer b.mutex.RUnlock() + return len(b.tokens) +} diff --git a/cache.go b/cache.go index 2ae81e3..2a2f6ef 100644 --- a/cache.go +++ b/cache.go @@ -128,6 +128,7 @@ func (c *Cache) Cleanup() { now := time.Now() for key, item := range c.items { + // Only remove items that are already expired if now.After(item.ExpiresAt) { c.removeItem(key) } @@ -136,8 +137,23 @@ func (c *Cache) Cleanup() { // evictOldest removes the least recently used item from the cache. func (c *Cache) evictOldest() { + now := time.Now() elem := c.order.Front() - if elem != nil { + + // First try to find an expired item from the front + for elem != nil { + entry := elem.Value.(lruEntry) + if item, exists := c.items[entry.key]; exists { + if now.After(item.ExpiresAt) { + c.removeItem(entry.key) + return + } + } + elem = elem.Next() + } + + // If no expired items found, remove the oldest item + if elem = c.order.Front(); elem != nil { entry := elem.Value.(lruEntry) c.removeItem(entry.key) } diff --git a/helpers.go b/helpers.go index 0dc936e..e18ffe2 100644 --- a/helpers.go +++ b/helpers.go @@ -11,7 +11,6 @@ import ( "net/http/cookiejar" "net/url" "strings" - "sync" "time" ) @@ -283,79 +282,6 @@ func extractClaims(tokenString string) (map[string]interface{}, error) { return claims, nil } -// TokenBlacklist maintains a thread-safe list of revoked tokens. -// It stores tokens with their expiration times and automatically -// removes expired entries during cleanup operations. -type TokenBlacklist struct { - // blacklist maps token IDs to their expiration times - blacklist map[string]time.Time - - // mutex protects concurrent access to the blacklist - mutex sync.RWMutex - - // maxSize is the maximum number of tokens in the blacklist - maxSize int -} - -// NewTokenBlacklist creates a new TokenBlacklist instance. -func NewTokenBlacklist() *TokenBlacklist { - return &TokenBlacklist{ - blacklist: make(map[string]time.Time), - maxSize: 1000, // Limit the size to prevent unbounded growth - } -} - -// Add adds a token to the blacklist with an expiration time. -func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) { - tb.mutex.Lock() - defer tb.mutex.Unlock() - - // Clean up expired tokens if we're at capacity - if len(tb.blacklist) >= tb.maxSize { - now := time.Now() - for token, exp := range tb.blacklist { - if now.After(exp) { - delete(tb.blacklist, token) - } - } - // If still at capacity after cleanup, remove oldest token - if len(tb.blacklist) >= tb.maxSize { - var oldestToken string - var oldestTime time.Time - first := true - for token, exp := range tb.blacklist { - if first || exp.Before(oldestTime) { - oldestToken = token - oldestTime = exp - first = false - } - } - delete(tb.blacklist, oldestToken) - } - } - tb.blacklist[tokenID] = expiration -} - -// IsBlacklisted checks if a token is in the blacklist and not expired. -func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { - tb.mutex.RLock() - defer tb.mutex.RUnlock() - expiration, exists := tb.blacklist[tokenID] - return exists && time.Now().Before(expiration) -} - -// Cleanup removes expired tokens from the blacklist. -func (tb *TokenBlacklist) Cleanup() { - tb.mutex.Lock() - defer tb.mutex.Unlock() - now := time.Now() - for tokenID, expiration := range tb.blacklist { - if now.After(expiration) { - delete(tb.blacklist, tokenID) - } - } -} - // TokenCache provides a caching mechanism for validated tokens. // It stores token claims to avoid repeated validation of the // same token, improving performance for frequently used tokens. diff --git a/helpers_test.go b/helpers_test.go index 1c87c0c..96a3c45 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -16,16 +16,16 @@ func TestTokenBlacklistSizeLimit(t *testing.T) { } // Verify size is at max - if len(tb.blacklist) != 1000 { - t.Errorf("Expected blacklist size to be 1000, got %d", len(tb.blacklist)) + if tb.Count() != 1000 { + t.Errorf("Expected blacklist size to be 1000, got %d", tb.Count()) } // Add one more token, should trigger cleanup/eviction tb.Add("newtoken", time.Now().Add(time.Hour)) // Size should still be at max - if len(tb.blacklist) > 1000 { - t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist)) + if tb.Count() > 1000 { + t.Errorf("Blacklist exceeded max size: %d", tb.Count()) } } @@ -46,12 +46,14 @@ func TestTokenBlacklistExpiredCleanup(t *testing.T) { tb.Cleanup() // Only valid tokens should remain - if len(tb.blacklist) != 500 { - t.Errorf("Expected 500 valid tokens after cleanup, got %d", len(tb.blacklist)) + if tb.Count() != 500 { + t.Errorf("Expected 500 valid tokens after cleanup, got %d", tb.Count()) } // Verify only valid tokens remain - for token, expiry := range tb.blacklist { + tb.mutex.RLock() + defer tb.mutex.RUnlock() + for token, expiry := range tb.tokens { if time.Now().After(expiry) { t.Errorf("Found expired token after cleanup: %s", token) } @@ -130,8 +132,8 @@ func TestTokenBlacklistMemoryUsage(t *testing.T) { } // Verify size stayed within limits - if len(tb.blacklist) > tb.maxSize { - t.Errorf("Blacklist exceeded max size: %d", len(tb.blacklist)) + if tb.Count() > 1000 { + t.Errorf("Blacklist exceeded max size: %d", tb.Count()) } } @@ -167,8 +169,8 @@ func TestConcurrentTokenBlacklistOperations(t *testing.T) { } // Verify size constraints were maintained - if len(tb.blacklist) > tb.maxSize { - t.Errorf("Blacklist exceeded max size under concurrent operations: %d", len(tb.blacklist)) + if tb.Count() > 1000 { + t.Errorf("Blacklist exceeded max size under concurrent operations: %d", tb.Count()) } } diff --git a/main.go b/main.go index 83a432b..ef50b64 100644 --- a/main.go +++ b/main.go @@ -326,25 +326,22 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() - for { - select { - case <-ticker.C: - t.logger.Debug("Refreshing OIDC metadata") - metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) - if err != nil { - t.logger.Errorf("Failed to refresh metadata: %v", err) - continue - } + for range ticker.C { + t.logger.Debug("Refreshing OIDC metadata") + metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) + if err != nil { + t.logger.Errorf("Failed to refresh metadata: %v", err) + continue + } - if metadata != nil { - t.jwksURL = metadata.JWKSURL - t.authURL = metadata.AuthURL - t.tokenURL = metadata.TokenURL - t.issuerURL = metadata.Issuer - t.revocationURL = metadata.RevokeURL - t.endSessionURL = metadata.EndSessionURL - t.logger.Debug("Successfully refreshed metadata") - } + if metadata != nil { + t.jwksURL = metadata.JWKSURL + t.authURL = metadata.AuthURL + t.tokenURL = metadata.TokenURL + t.issuerURL = metadata.Issuer + t.revocationURL = metadata.RevokeURL + t.endSessionURL = metadata.EndSessionURL + t.logger.Debug("Successfully refreshed metadata") } } } @@ -719,39 +716,34 @@ func (t *TraefikOidc) startTokenCleanup() { defer ticker.Stop() defer cancel() - for { + for range ticker.C { + t.logger.Debug("Starting token cleanup cycle") + + // Run cleanup in a separate goroutine with shorter timeout + cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 5*time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + // Clean up in smaller batches to prevent long-running operations + t.tokenCache.Cleanup() + t.tokenBlacklist.Cleanup() + + // Force garbage collection after cleanup + runtime.GC() + }() + + // Wait for cleanup to complete or timeout select { - case <-ctx.Done(): - return - case <-ticker.C: - t.logger.Debug("Starting token cleanup cycle") - - // Run cleanup in a separate goroutine with shorter timeout - cleanupCtx, cleanupCancel := context.WithTimeout(ctx, 5*time.Second) - done := make(chan struct{}) - - go func() { - defer close(done) - // Clean up in smaller batches to prevent long-running operations - t.tokenCache.Cleanup() - t.tokenBlacklist.Cleanup() - - // Force garbage collection after cleanup - runtime.GC() - }() - - // Wait for cleanup to complete or timeout - select { - case <-cleanupCtx.Done(): - if cleanupCtx.Err() == context.DeadlineExceeded { - t.logger.Error("Token cleanup cycle timed out") - } - case <-done: - t.logger.Debug("Token cleanup cycle completed successfully") + case <-cleanupCtx.Done(): + if cleanupCtx.Err() == context.DeadlineExceeded { + t.logger.Error("Token cleanup cycle timed out") } - - cleanupCancel() + case <-done: + t.logger.Debug("Token cleanup cycle completed successfully") } + + cleanupCancel() } }() }