Multiple fixes

- Unbounded Replay Cache: Now bounded to 10,000 entries with automatic cleanup
- Session Pool Leaks: Proper object lifecycle prevents accumulation
- HTTP Client Leaks: Reusable clients eliminate connection overhead
- Goroutine Leaks: Tracked lifecycle with graceful shutdown
This commit is contained in:
2025-05-23 10:55:57 +01:00
parent 82a640cc3b
commit 99881f5837
6 changed files with 279 additions and 79 deletions
+18 -13
View File
@@ -123,19 +123,24 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
data.Set("refresh_token", codeOrToken)
}
// Create a cookie jar for this request to handle redirects with cookies
jar, _ := cookiejar.New(nil)
client := &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
// Use the reusable token HTTP client, fallback to creating one if not initialized
client := t.tokenHTTPClient
if client == nil {
// Fallback for tests or incomplete initialization - create a temporary client
// with the same behavior as the original implementation
jar, _ := cookiejar.New(nil)
client = &http.Client{
Transport: t.httpClient.Transport,
Timeout: t.httpClient.Timeout,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Always follow redirects for OIDC endpoints
if len(via) >= 50 {
return fmt.Errorf("stopped after 50 redirects")
}
return nil
},
Jar: jar,
}
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
+17 -26
View File
@@ -17,26 +17,14 @@ import (
var (
replayCacheMu sync.Mutex
replayCache = make(map[string]time.Time)
replayCache *Cache // Replace unbounded map with bounded Cache
)
// cleanupReplayCache iterates through the replay cache and removes entries
// whose expiration time is before the current time. This function should be
// called periodically to prevent the cache from growing indefinitely.
// It acquires a mutex to ensure thread safety during cleanup.
// SECURITY FIX: Add proper locking protection for cleanupReplayCache
func cleanupReplayCache() {
now := time.Now()
// SECURITY FIX: Use safe iteration with proper locking
toDelete := make([]string, 0)
for token, expiry := range replayCache {
if expiry.Before(now) {
toDelete = append(toDelete, token)
}
}
// Delete expired entries
for _, token := range toDelete {
delete(replayCache, token)
// initReplayCache lazily creates the package-level replay cache and caps it
// at 10,000 entries. It is a no-op when the cache already exists.
//
// The function takes no lock itself; the visible call sites invoke it while
// holding replayCacheMu.
func initReplayCache() {
	if replayCache != nil {
		return
	}

	// Configure the cache fully before publishing it through the global.
	cache := NewCache()
	cache.SetMaxSize(10000)
	replayCache = cache
}
@@ -203,15 +191,15 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// SECURITY FIX: Implement thread-safe replay cache operations with proper locking
// SECURITY FIX: Use bounded Cache with thread-safe operations
replayCacheMu.Lock()
defer replayCacheMu.Unlock() // Ensure unlock happens even if panic occurs
defer replayCacheMu.Unlock()
// SECURITY FIX: Clean up expired entries safely
cleanupReplayCache()
// Initialize cache if not already done
initReplayCache()
// SECURITY FIX: Check for replay attack with atomic operation
if _, exists := replayCache[jti]; exists {
// SECURITY FIX: Check for replay attack using Cache API
if _, exists := replayCache.Get(jti); exists {
return fmt.Errorf("token replay detected")
}
@@ -224,8 +212,11 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
expTime = time.Now().Add(10 * time.Minute)
}
// SECURITY FIX: Add to replay cache atomically
replayCache[jti] = expTime
// SECURITY FIX: Add to replay cache with expiration using Cache API
duration := time.Until(expTime)
if duration > 0 {
replayCache.Set(jti, true, duration)
}
}
sub, ok := claims["sub"].(string)
+87 -26
View File
@@ -9,9 +9,11 @@ import (
"math"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"runtime"
"strings"
"sync"
"text/template"
"time"
@@ -59,6 +61,33 @@ func createDefaultHTTPClient() *http.Client {
}
}
// createTokenHTTPClient creates a specialized HTTP client for token operations.
// It reuses the transport and timeout of the base HTTP client, attaches a cookie
// jar so cookies survive redirects, and installs a redirect policy that follows
// redirects until 50 requests have been made.
//
// NOTE(review): the returned client is intended to be reused, so its cookie jar
// is shared by every request made through it; cookies set by the provider during
// one token exchange stay in the jar for later ones — confirm this is acceptable.
//
// Parameters:
//   - baseClient: The base HTTP client to derive transport settings from. When
//     nil, the default transport and a 30 second timeout are used instead of
//     panicking on a nil dereference.
//
// Returns:
//   - A pointer to the configured http.Client optimized for token operations.
func createTokenHTTPClient(baseClient *http.Client) *http.Client {
	const maxRedirects = 50

	if baseClient == nil {
		// A nil Transport makes net/http fall back to http.DefaultTransport.
		baseClient = &http.Client{Timeout: 30 * time.Second}
	}

	// Keep the jar as an interface value: storing a nil *cookiejar.Jar in the
	// http.CookieJar field would produce a non-nil interface wrapping a nil
	// pointer. If the jar cannot be created the client simply runs without one.
	var jar http.CookieJar
	if created, err := cookiejar.New(nil); err == nil {
		jar = created
	}

	return &http.Client{
		Transport: baseClient.Transport, // Reuse the transport from base client
		Timeout:   baseClient.Timeout,   // Reuse the timeout from base client
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// Always follow redirects for OIDC endpoints, up to the limit.
			if len(via) >= maxRedirects {
				return fmt.Errorf("stopped after %d redirects", maxRedirects)
			}
			return nil
		},
		Jar: jar, // Cookie jar for redirect handling
	}
}
const (
ConstSessionTimeout = 86400 // Session timeout in seconds
defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI
@@ -105,6 +134,7 @@ type TraefikOidc struct {
scheme string
tokenCache *TokenCache
httpClient *http.Client
tokenHTTPClient *http.Client // Reusable HTTP client for token operations
logger *Logger
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
@@ -124,6 +154,7 @@ type TraefikOidc struct {
headerTemplates map[string]*template.Template // Parsed templates for custom headers
tokenCleanupStopChan chan struct{} // Channel to stop token cleanup goroutine
metadataRefreshStopChan chan struct{} // Channel to stop metadata refresh goroutine
goroutineWG sync.WaitGroup // WaitGroup to track background goroutines
}
// ProviderMetadata holds OIDC provider metadata
@@ -243,7 +274,14 @@ func (t *TraefikOidc) VerifyToken(token string) error {
// Also update the global replayCache for backwards compatibility
replayCacheMu.Lock()
replayCache[jti] = expiry
// Initialize cache if not already done
if replayCache == nil {
initReplayCache()
}
duration := time.Until(expiry)
if duration > 0 {
replayCache.Set(jti, true, duration)
}
replayCacheMu.Unlock()
}
@@ -416,6 +454,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
tokenHTTPClient: createTokenHTTPClient(httpClient),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
@@ -526,30 +565,34 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
// - providerURL: The base URL of the OIDC provider, used for subsequent refresh attempts.
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
ticker := time.NewTicker(1 * time.Hour)
// No defer ticker.Stop() here, it's stopped in the select case
t.goroutineWG.Add(1) // Track this goroutine
for {
select {
case <-ticker.C:
t.logger.Debug("Refreshing OIDC metadata")
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to refresh metadata: %v", err)
continue
}
go func() {
defer t.goroutineWG.Done() // Signal completion when goroutine exits
defer ticker.Stop() // Ensure ticker is always stopped
if metadata != nil {
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
} else {
t.logger.Error("Received nil metadata during refresh")
for {
select {
case <-ticker.C:
t.logger.Debug("Refreshing OIDC metadata")
metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to refresh metadata: %v", err)
continue
}
if metadata != nil {
t.updateMetadataEndpoints(metadata)
t.logger.Debug("Successfully refreshed metadata")
} else {
t.logger.Error("Received nil metadata during refresh")
}
case <-t.metadataRefreshStopChan:
t.logger.Debug("Metadata refresh goroutine stopped.")
return
}
case <-t.metadataRefreshStopChan:
ticker.Stop()
t.logger.Debug("Metadata refresh goroutine stopped.")
return
}
}
}()
}
// discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its
@@ -1720,8 +1763,11 @@ func (t *TraefikOidc) validateHost(host string) error {
// the token cache, token blacklist cache, and JWK cache.
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
t.goroutineWG.Add(1) // Track this goroutine
go func() {
// No defer ticker.Stop() here, it's stopped in the select case
defer t.goroutineWG.Done() // Signal completion when goroutine exits
defer ticker.Stop() // Ensure ticker is always stopped
for {
select {
case <-ticker.C:
@@ -1741,7 +1787,6 @@ func (t *TraefikOidc) startTokenCleanup() {
t.jwkCache.Cleanup()
}
case <-t.tokenCleanupStopChan:
ticker.Stop()
t.logger.Debug("Token cleanup goroutine stopped.")
return
}
@@ -2231,20 +2276,36 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
_, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent
}
// Close stops all background goroutines and closes resources.
// Close stops all background goroutines and closes resources with proper timeout.
func (t *TraefikOidc) Close() error {
t.logger.Debug("Closing TraefikOidc plugin instance")
// Signal and close tokenCleanupStopChan
// Signal all goroutines to stop
if t.tokenCleanupStopChan != nil {
close(t.tokenCleanupStopChan)
t.logger.Debug("tokenCleanupStopChan closed")
}
// Signal and close metadataRefreshStopChan
if t.metadataRefreshStopChan != nil {
close(t.metadataRefreshStopChan)
t.logger.Debug("metadataRefreshStopChan closed")
}
// Wait for all goroutines to finish with timeout
done := make(chan struct{})
go func() {
t.goroutineWG.Wait()
close(done)
}()
// Wait for goroutines to finish or timeout after 10 seconds
select {
case <-done:
t.logger.Debug("All background goroutines stopped gracefully")
case <-time.After(10 * time.Second):
t.logger.Errorf("Timeout waiting for background goroutines to stop")
// Continue with cleanup even if goroutines didn't stop gracefully
}
// Close caches
// These Close methods should stop their respective autoCleanupRoutine goroutines
if t.tokenBlacklist != nil {
+6 -3
View File
@@ -722,7 +722,8 @@ func TestServeHTTP(t *testing.T) {
// Reset the global replayCache to prevent "token replay detected" errors
replayCacheMu.Lock()
replayCache = make(map[string]time.Time) // Reset the global cache
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Store original tokenVerifier to restore later
@@ -734,7 +735,8 @@ func TestServeHTTP(t *testing.T) {
VerifyFunc: func(token string) error {
// Clear replay cache before token verification
replayCacheMu.Lock()
replayCache = make(map[string]time.Time)
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Call the original verifier's VerifyToken method
@@ -1143,7 +1145,8 @@ func TestHandleCallback(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Clear the global replay cache before each test run
replayCacheMu.Lock()
replayCache = make(map[string]time.Time) // Reset the global cache
replayCache = NewCache()
replayCache.SetMaxSize(10000)
replayCacheMu.Unlock()
// Explicitly clear the shared blacklist at the start of each sub-test
+93 -6
View File
@@ -36,6 +36,10 @@ type PerformanceMetrics struct {
// Resource metrics
memoryUsage int64
goroutineCount int64
memoryPressure int64 // Memory pressure level (0-100)
gcPauseTime int64 // Last GC pause time in nanoseconds
heapSize int64 // Current heap size
heapInUse int64 // Heap memory in use
// Error metrics (kept for backward compatibility)
verificationErrors int64
@@ -251,9 +255,36 @@ func (pm *PerformanceMetrics) collectSystemMetrics() {
var m runtime.MemStats
runtime.ReadMemStats(&m)
atomic.StoreInt64(&pm.memoryUsage, int64(m.Alloc))
atomic.StoreInt64(&pm.heapSize, int64(m.HeapSys))
atomic.StoreInt64(&pm.heapInUse, int64(m.HeapInuse))
atomic.StoreInt64(&pm.gcPauseTime, int64(m.PauseNs[(m.NumGC+255)%256]))
// Calculate memory pressure (0-100 scale)
// Based on heap utilization and GC frequency
heapUtilization := float64(m.HeapInuse) / float64(m.HeapSys)
gcFrequency := float64(m.NumGC) / time.Since(pm.startTime).Minutes()
// Memory pressure calculation
pressure := int64(heapUtilization * 50) // 0-50 based on heap utilization
if gcFrequency > 10 { // High GC frequency indicates pressure
pressure += int64((gcFrequency - 10) * 2) // Add up to 50 more
}
if pressure > 100 {
pressure = 100
}
atomic.StoreInt64(&pm.memoryPressure, pressure)
// Goroutine count
atomic.StoreInt64(&pm.goroutineCount, int64(runtime.NumGoroutine()))
// Log memory pressure warnings
if pressure > 80 {
pm.logger.Errorf("High memory pressure detected: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
pressure, heapUtilization*100, gcFrequency)
} else if pressure > 60 {
pm.logger.Infof("Moderate memory pressure: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)",
pressure, heapUtilization*100, gcFrequency)
}
}
// GetMetrics returns all current performance metrics
@@ -317,6 +348,10 @@ func (pm *PerformanceMetrics) GetMetrics() map[string]interface{} {
// Resource metrics
"memory_usage_bytes": atomic.LoadInt64(&pm.memoryUsage),
"memory_pressure": atomic.LoadInt64(&pm.memoryPressure),
"heap_size_bytes": atomic.LoadInt64(&pm.heapSize),
"heap_inuse_bytes": atomic.LoadInt64(&pm.heapInUse),
"gc_pause_time_ns": atomic.LoadInt64(&pm.gcPauseTime),
"goroutine_count": atomic.LoadInt64(&pm.goroutineCount),
// Rate limiting metrics
@@ -414,6 +449,10 @@ type ResourceMonitor struct {
// Session limits
maxSessions int64
// Cache size tracking
cacheSizes map[string]int64
cacheMutex sync.RWMutex
// Monitoring state
alertThresholds map[string]float64
alerts []ResourceAlert
@@ -441,11 +480,13 @@ func NewResourceMonitor(perfMetrics *PerformanceMetrics, logger *Logger) *Resour
maxMemoryBytes: 100 * 1024 * 1024, // 100MB default
maxCacheSize: 10000, // 10k items default
maxSessions: 1000, // 1k sessions default
cacheSizes: make(map[string]int64),
alertThresholds: map[string]float64{
"memory_usage": 0.8, // 80%
"cache_usage": 0.9, // 90%
"session_usage": 0.85, // 85%
"error_rate": 0.1, // 10%
"memory_usage": 0.8, // 80%
"memory_pressure": 0.7, // 70%
"cache_usage": 0.9, // 90%
"session_usage": 0.85, // 85%
"error_rate": 0.1, // 10%
},
alerts: make([]ResourceAlert, 0),
perfMetrics: perfMetrics,
@@ -473,6 +514,25 @@ func (rm *ResourceMonitor) SetSessionLimit(count int64) {
rm.maxSessions = count
}
// UpdateCacheSize records size as the current size of the cache identified by
// cacheName, replacing any value recorded earlier under the same name.
// The write is performed while holding cacheMutex.
func (rm *ResourceMonitor) UpdateCacheSize(cacheName string, size int64) {
	rm.cacheMutex.Lock()
	defer rm.cacheMutex.Unlock()

	// Writing to a nil map panics; allocate lazily so a monitor that was not
	// built through NewResourceMonitor can still record sizes.
	if rm.cacheSizes == nil {
		rm.cacheSizes = make(map[string]int64)
	}
	rm.cacheSizes[cacheName] = size
}
// GetCacheSizes returns a snapshot of the recorded cache sizes keyed by cache
// name. The returned map is a copy made under a read lock, so callers may
// modify it without affecting the monitor's internal state.
func (rm *ResourceMonitor) GetCacheSizes() map[string]int64 {
	rm.cacheMutex.RLock()
	defer rm.cacheMutex.RUnlock()

	// The final length is known up front, so size the copy once.
	snapshot := make(map[string]int64, len(rm.cacheSizes))
	for name, size := range rm.cacheSizes {
		snapshot[name] = size
	}
	return snapshot
}
// startMonitoring starts the background monitoring routine
func (rm *ResourceMonitor) startMonitoring() {
ticker := time.NewTicker(10 * time.Second)
@@ -502,6 +562,21 @@ func (rm *ResourceMonitor) checkResourceUsage() {
}
}
// Check memory pressure
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
pressureRatio := float64(memPressure) / 100.0 // Convert to 0-1 scale
if pressureRatio > rm.alertThresholds["memory_pressure"] {
rm.addAlert(ResourceAlert{
Type: "memory_pressure",
Message: "Memory pressure exceeds threshold",
Threshold: rm.alertThresholds["memory_pressure"],
CurrentValue: pressureRatio,
Timestamp: time.Now(),
Severity: rm.getSeverity(pressureRatio, rm.alertThresholds["memory_pressure"]),
})
}
}
// Check cache usage
if cacheSize, ok := metrics["cache_size"].(int64); ok {
cacheUsageRatio := float64(cacheSize) / float64(rm.maxCacheSize)
@@ -592,6 +667,7 @@ func (rm *ResourceMonitor) GetAlerts() []ResourceAlert {
// GetResourceStatus returns current resource status
func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
metrics := rm.perfMetrics.GetMetrics()
cacheSizes := rm.GetCacheSizes()
status := map[string]interface{}{
"limits": map[string]interface{}{
@@ -599,8 +675,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
"max_cache_size": rm.maxCacheSize,
"max_sessions": rm.maxSessions,
},
"thresholds": rm.alertThresholds,
"current": metrics,
"thresholds": rm.alertThresholds,
"current": metrics,
"cache_sizes": cacheSizes,
// Add expected keys for tests
"memory_limit": uint64(rm.maxMemoryBytes),
"cache_limit": int(rm.maxCacheSize),
@@ -611,6 +688,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok {
status["memory_usage_ratio"] = float64(memUsage) / float64(rm.maxMemoryBytes)
}
if memPressure, ok := metrics["memory_pressure"].(int64); ok {
status["memory_pressure_ratio"] = float64(memPressure) / 100.0
}
if cacheSize, ok := metrics["cache_size"].(int64); ok {
status["cache_usage_ratio"] = float64(cacheSize) / float64(rm.maxCacheSize)
}
@@ -618,5 +698,12 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} {
status["session_usage_ratio"] = float64(activeSessions) / float64(rm.maxSessions)
}
// Calculate total cache size across all caches
var totalCacheSize int64
for _, size := range cacheSizes {
totalCacheSize += size
}
status["total_cache_size"] = totalCacheSize
return status
}
+58 -5
View File
@@ -190,13 +190,18 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
// Initialize SessionData with necessary fields and the mutex.
return &SessionData{
sd := &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
refreshMutex: sync.Mutex{}, // Initialize the mutex
dirty: false, // Initialize dirty flag
refreshMutex: sync.Mutex{}, // Initialize the mutex
sessionMutex: sync.RWMutex{}, // Initialize the session mutex
dirty: false, // Initialize dirty flag
inUse: false, // Initialize in-use flag
}
// Ensure the object is properly reset when created
sd.Reset()
return sd
}
return sm, nil
@@ -482,6 +487,8 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// STABILITY FIX: Mark as not in use and return session to pool, regardless of error.
// This ensures the session is always returned to the pool, preventing memory leaks.
sd.inUse = false
// Reset the session data before returning to pool to prevent data leakage
sd.Reset()
sd.manager.sessionPool.Put(sd)
// Return the error from Save, if any
@@ -602,6 +609,52 @@ func (sd *SessionData) SetAuthenticated(value bool) error {
return nil
}
// Reset wipes this SessionData so the object can be handed out again without
// carrying over anything from its previous use. For each of the main, access
// and refresh sessions that is set, it removes every stored value, blanks the
// ID and marks the session as new; it then empties both token chunk maps and
// clears the dirty flag, the in-use flag and the request reference.
//
// The whole operation runs under sessionMutex.
// NOTE(review): callers must not already hold sessionMutex when calling Reset,
// or this will block — verify against Clear and ReturnToPool.
func (sd *SessionData) Reset() {
	sd.sessionMutex.Lock()
	defer sd.sessionMutex.Unlock()

	// Bookkeeping state.
	sd.request = nil
	sd.inUse = false
	sd.dirty = false

	// Token chunk maps: drop every entry but keep the maps themselves.
	for idx := range sd.accessTokenChunks {
		delete(sd.accessTokenChunks, idx)
	}
	for idx := range sd.refreshTokenChunks {
		delete(sd.refreshTokenChunks, idx)
	}

	// Underlying sessions: skip any that were never attached.
	if main := sd.mainSession; main != nil {
		for key := range main.Values {
			delete(main.Values, key)
		}
		main.IsNew = true
		main.ID = ""
	}

	if access := sd.accessSession; access != nil {
		for key := range access.Values {
			delete(access.Values, key)
		}
		access.IsNew = true
		access.ID = ""
	}

	if refresh := sd.refreshSession; refresh != nil {
		for key := range refresh.Values {
			delete(refresh.Values, key)
		}
		refresh.IsNew = true
		refresh.ID = ""
	}
}
// ReturnToPool explicitly returns this SessionData object to the pool.
// This should be called when you're done with a SessionData in any error path
// where Clear() is not called, to prevent memory leaks.
@@ -609,8 +662,8 @@ func (sd *SessionData) ReturnToPool() {
if sd != nil && sd.manager != nil {
// STABILITY FIX: Only return to pool if not currently in use
if !sd.inUse {
// Clear request reference to avoid memory leaks
sd.request = nil
// Reset the session data before returning to pool
sd.Reset()
sd.manager.sessionPool.Put(sd)
}
}