diff --git a/helpers.go b/helpers.go index 9f9529f..97d0531 100644 --- a/helpers.go +++ b/helpers.go @@ -123,19 +123,24 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code data.Set("refresh_token", codeOrToken) } - // Create a cookie jar for this request to handle redirects with cookies - jar, _ := cookiejar.New(nil) - client := &http.Client{ - Transport: t.httpClient.Transport, - Timeout: t.httpClient.Timeout, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Always follow redirects for OIDC endpoints - if len(via) >= 50 { - return fmt.Errorf("stopped after 50 redirects") - } - return nil - }, - Jar: jar, + // Use the reusable token HTTP client, fallback to creating one if not initialized + client := t.tokenHTTPClient + if client == nil { + // Fallback for tests or incomplete initialization - create a temporary client + // with the same behavior as the original implementation + jar, _ := cookiejar.New(nil) + client = &http.Client{ + Transport: t.httpClient.Transport, + Timeout: t.httpClient.Timeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Always follow redirects for OIDC endpoints + if len(via) >= 50 { + return fmt.Errorf("stopped after 50 redirects") + } + return nil + }, + Jar: jar, + } } req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) diff --git a/jwt.go b/jwt.go index 4479ac1..dde5ac3 100644 --- a/jwt.go +++ b/jwt.go @@ -17,26 +17,14 @@ import ( var ( replayCacheMu sync.Mutex - replayCache = make(map[string]time.Time) + replayCache *Cache // Replace unbounded map with bounded Cache ) -// cleanupReplayCache iterates through the replay cache and removes entries -// whose expiration time is before the current time. This function should be -// called periodically to prevent the cache from growing indefinitely. -// It acquires a mutex to ensure thread safety during cleanup. -// SECURITY FIX: Add proper locking protection for cleanupReplayCache -func cleanupReplayCache() { - now := time.Now() - // SECURITY FIX: Use safe iteration with proper locking - toDelete := make([]string, 0) - for token, expiry := range replayCache { - if expiry.Before(now) { - toDelete = append(toDelete, token) - } - } - // Delete expired entries - for _, token := range toDelete { - delete(replayCache, token) +// initReplayCache initializes the global replay cache with size limit +func initReplayCache() { + if replayCache == nil { + replayCache = NewCache() + replayCache.SetMaxSize(10000) // Set size limit to 10,000 entries } } @@ -203,15 +191,15 @@ func (j *JWT) Verify(issuerURL, clientID string) error { return nil } - // SECURITY FIX: Implement thread-safe replay cache operations with proper locking + // SECURITY FIX: Use bounded Cache with thread-safe operations replayCacheMu.Lock() - defer replayCacheMu.Unlock() // Ensure unlock happens even if panic occurs + defer replayCacheMu.Unlock() - // SECURITY FIX: Clean up expired entries safely - cleanupReplayCache() + // Initialize cache if not already done + initReplayCache() - // SECURITY FIX: Check for replay attack with atomic operation - if _, exists := replayCache[jti]; exists { + // SECURITY FIX: Check for replay attack using Cache API + if _, exists := replayCache.Get(jti); exists { return fmt.Errorf("token replay detected") } @@ -224,8 +212,11 @@ func (j *JWT) Verify(issuerURL, clientID string) error { expTime = time.Now().Add(10 * time.Minute) } - // SECURITY FIX: Add to replay cache atomically - replayCache[jti] = expTime + // SECURITY FIX: Add to replay cache with expiration using Cache API + duration := time.Until(expTime) + if duration > 0 { + replayCache.Set(jti, true, duration) + } } sub, ok := claims["sub"].(string) diff --git a/main.go b/main.go index ef754de..576253e 100644 --- a/main.go +++ b/main.go @@ -9,9 +9,11 @@ import ( "math" "net" "net/http" + "net/http/cookiejar" "net/url" "runtime" "strings" + "sync" "text/template" "time" @@ -59,6 +61,33 @@ func createDefaultHTTPClient() *http.Client { } } +// createTokenHTTPClient creates a specialized HTTP client for token operations. +// It reuses the transport from the main HTTP client but adds cookie jar support +// and optimized redirect handling for OIDC token endpoints. +// +// Parameters: +// - baseClient: The base HTTP client to derive transport settings from. +// +// Returns: +// - A pointer to the configured http.Client optimized for token operations. +func createTokenHTTPClient(baseClient *http.Client) *http.Client { + // Create a cookie jar for handling redirects with cookies + jar, _ := cookiejar.New(nil) + + return &http.Client{ + Transport: baseClient.Transport, // Reuse the transport from base client + Timeout: baseClient.Timeout, // Reuse the timeout from base client + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Always follow redirects for OIDC endpoints + if len(via) >= 50 { + return fmt.Errorf("stopped after 50 redirects") + } + return nil + }, + Jar: jar, // Add cookie jar for redirect handling + } +} + const ( ConstSessionTimeout = 86400 // Session timeout in seconds defaultBlacklistDuration = 24 * time.Hour // Default duration to blacklist a JTI @@ -105,6 +134,7 @@ type TraefikOidc struct { scheme string tokenCache *TokenCache httpClient *http.Client + tokenHTTPClient *http.Client // Reusable HTTP client for token operations logger *Logger tokenVerifier TokenVerifier jwtVerifier JWTVerifier @@ -124,6 +154,7 @@ type TraefikOidc struct { headerTemplates map[string]*template.Template // Parsed templates for custom headers tokenCleanupStopChan chan struct{} // Channel to stop token cleanup goroutine metadataRefreshStopChan chan struct{} // Channel to stop metadata refresh goroutine + goroutineWG sync.WaitGroup // WaitGroup to track background goroutines } // ProviderMetadata holds OIDC provider metadata @@ -243,7 +274,14 @@ func (t *TraefikOidc) VerifyToken(token string) error { // Also update the global replayCache for backwards compatibility replayCacheMu.Lock() - replayCache[jti] = expiry + // Initialize cache if not already done + if replayCache == nil { + initReplayCache() + } + duration := time.Until(expiry) + if duration > 0 { + replayCache.Set(jti, true, duration) + } replayCacheMu.Unlock() } @@ -416,6 +454,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), tokenCache: NewTokenCache(), httpClient: httpClient, + tokenHTTPClient: createTokenHTTPClient(httpClient), excludedURLs: createStringMap(config.ExcludedURLs), allowedUserDomains: createStringMap(config.AllowedUserDomains), allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers), @@ -526,30 +565,34 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) { // - providerURL: The base URL of bogged OIDC provider, used for subsequent refresh attempts. func (t *TraefikOidc) startMetadataRefresh(providerURL string) { ticker := time.NewTicker(1 * time.Hour) - // No defer ticker.Stop() here, it's stopped in the select case + t.goroutineWG.Add(1) // Track this goroutine - for { - select { - case <-ticker.C: - t.logger.Debug("Refreshing OIDC metadata") - metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) - if err != nil { - t.logger.Errorf("Failed to refresh metadata: %v", err) - continue - } + go func() { + defer t.goroutineWG.Done() // Signal completion when goroutine exits + defer ticker.Stop() // Ensure ticker is always stopped - if metadata != nil { - t.updateMetadataEndpoints(metadata) - t.logger.Debug("Successfully refreshed metadata") - } else { - t.logger.Error("Received nil metadata during refresh") + for { + select { + case <-ticker.C: + t.logger.Debug("Refreshing OIDC metadata") + metadata, err := t.metadataCache.GetMetadata(providerURL, t.httpClient, t.logger) + if err != nil { + t.logger.Errorf("Failed to refresh metadata: %v", err) + continue + } + + if metadata != nil { + t.updateMetadataEndpoints(metadata) + t.logger.Debug("Successfully refreshed metadata") + } else { + t.logger.Error("Received nil metadata during refresh") + } + case <-t.metadataRefreshStopChan: + t.logger.Debug("Metadata refresh goroutine stopped.") + return } - case <-t.metadataRefreshStopChan: - ticker.Stop() - t.logger.Debug("Metadata refresh goroutine stopped.") - return } - } + }() } // discoverProviderMetadata attempts to fetch the OIDC provider's configuration from its @@ -1720,8 +1763,11 @@ func (t *TraefikOidc) validateHost(host string) error { // the token cache, token blacklist cache, and JWK cache. func (t *TraefikOidc) startTokenCleanup() { ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute + t.goroutineWG.Add(1) // Track this goroutine go func() { - // No defer ticker.Stop() here, it's stopped in the select case + defer t.goroutineWG.Done() // Signal completion when goroutine exits + defer ticker.Stop() // Ensure ticker is always stopped + for { select { case <-ticker.C: @@ -1741,7 +1787,6 @@ func (t *TraefikOidc) startTokenCleanup() { t.jwkCache.Cleanup() } case <-t.tokenCleanupStopChan: - ticker.Stop() t.logger.Debug("Token cleanup goroutine stopped.") return } @@ -2231,20 +2276,36 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques _, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent } -// Close stops all background goroutines and closes resources. +// Close stops all background goroutines and closes resources with proper timeout. func (t *TraefikOidc) Close() error { t.logger.Debug("Closing TraefikOidc plugin instance") - // Signal and close tokenCleanupStopChan + + // Signal all goroutines to stop if t.tokenCleanupStopChan != nil { close(t.tokenCleanupStopChan) t.logger.Debug("tokenCleanupStopChan closed") } - // Signal and close metadataRefreshStopChan if t.metadataRefreshStopChan != nil { close(t.metadataRefreshStopChan) t.logger.Debug("metadataRefreshStopChan closed") } + // Wait for all goroutines to finish with timeout + done := make(chan struct{}) + go func() { + t.goroutineWG.Wait() + close(done) + }() + + // Wait for goroutines to finish or timeout after 10 seconds + select { + case <-done: + t.logger.Debug("All background goroutines stopped gracefully") + case <-time.After(10 * time.Second): + t.logger.Errorf("Timeout waiting for background goroutines to stop") + // Continue with cleanup even if goroutines didn't stop gracefully + } + // Close caches // These Close methods should stop their respective autoCleanupRoutine goroutines if t.tokenBlacklist != nil { diff --git a/main_test.go b/main_test.go index 4ab7a48..a65e74a 100644 --- a/main_test.go +++ b/main_test.go @@ -722,7 +722,8 @@ func TestServeHTTP(t *testing.T) { // Reset the global replayCache to prevent "token replay detected" errors replayCacheMu.Lock() - replayCache = make(map[string]time.Time) // Reset the global cache + replayCache = NewCache() + replayCache.SetMaxSize(10000) replayCacheMu.Unlock() // Store original tokenVerifier to restore later @@ -734,7 +735,8 @@ func TestServeHTTP(t *testing.T) { VerifyFunc: func(token string) error { // Clear replay cache before token verification replayCacheMu.Lock() - replayCache = make(map[string]time.Time) + replayCache = NewCache() + replayCache.SetMaxSize(10000) replayCacheMu.Unlock() // Call the original verifier's VerifyToken method @@ -1143,7 +1145,8 @@ func TestHandleCallback(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Clear the global replay cache before each test run replayCacheMu.Lock() - replayCache = make(map[string]time.Time) // Reset the global cache + replayCache = NewCache() + replayCache.SetMaxSize(10000) replayCacheMu.Unlock() // Explicitly clear the shared blacklist at the start of each sub-test diff --git a/performance_monitoring.go b/performance_monitoring.go index 3845410..8037c9d 100644 --- a/performance_monitoring.go +++ b/performance_monitoring.go @@ -36,6 +36,10 @@ type PerformanceMetrics struct { // Resource metrics memoryUsage int64 goroutineCount int64 + memoryPressure int64 // Memory pressure level (0-100) + gcPauseTime int64 // Last GC pause time in nanoseconds + heapSize int64 // Current heap size + heapInUse int64 // Heap memory in use // Error metrics (kept for backward compatibility) verificationErrors int64 @@ -251,9 +255,36 @@ func (pm *PerformanceMetrics) collectSystemMetrics() { var m runtime.MemStats runtime.ReadMemStats(&m) atomic.StoreInt64(&pm.memoryUsage, int64(m.Alloc)) + atomic.StoreInt64(&pm.heapSize, int64(m.HeapSys)) + atomic.StoreInt64(&pm.heapInUse, int64(m.HeapInuse)) + atomic.StoreInt64(&pm.gcPauseTime, int64(m.PauseNs[(m.NumGC+255)%256])) + + // Calculate memory pressure (0-100 scale) + // Based on heap utilization and GC frequency + heapUtilization := float64(m.HeapInuse) / float64(m.HeapSys) + gcFrequency := float64(m.NumGC) / time.Since(pm.startTime).Minutes() + + // Memory pressure calculation + pressure := int64(heapUtilization * 50) // 0-50 based on heap utilization + if gcFrequency > 10 { // High GC frequency indicates pressure + pressure += int64((gcFrequency - 10) * 2) // Add up to 50 more + } + if pressure > 100 { + pressure = 100 + } + atomic.StoreInt64(&pm.memoryPressure, pressure) // Goroutine count atomic.StoreInt64(&pm.goroutineCount, int64(runtime.NumGoroutine())) + + // Log memory pressure warnings + if pressure > 80 { + pm.logger.Errorf("High memory pressure detected: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)", + pressure, heapUtilization*100, gcFrequency) + } else if pressure > 60 { + pm.logger.Infof("Moderate memory pressure: %d%% (heap utilization: %.1f%%, GC frequency: %.1f/min)", + pressure, heapUtilization*100, gcFrequency) + } } // GetMetrics returns all current performance metrics @@ -317,6 +348,10 @@ func (pm *PerformanceMetrics) GetMetrics() map[string]interface{} { // Resource metrics "memory_usage_bytes": atomic.LoadInt64(&pm.memoryUsage), + "memory_pressure": atomic.LoadInt64(&pm.memoryPressure), + "heap_size_bytes": atomic.LoadInt64(&pm.heapSize), + "heap_inuse_bytes": atomic.LoadInt64(&pm.heapInUse), + "gc_pause_time_ns": atomic.LoadInt64(&pm.gcPauseTime), "goroutine_count": atomic.LoadInt64(&pm.goroutineCount), // Rate limiting metrics @@ -414,6 +449,10 @@ type ResourceMonitor struct { // Session limits maxSessions int64 + // Cache size tracking + cacheSizes map[string]int64 + cacheMutex sync.RWMutex + // Monitoring state alertThresholds map[string]float64 alerts []ResourceAlert @@ -441,11 +480,13 @@ func NewResourceMonitor(perfMetrics *PerformanceMetrics, logger *Logger) *Resour maxMemoryBytes: 100 * 1024 * 1024, // 100MB default maxCacheSize: 10000, // 10k items default maxSessions: 1000, // 1k sessions default + cacheSizes: make(map[string]int64), alertThresholds: map[string]float64{ - "memory_usage": 0.8, // 80% - "cache_usage": 0.9, // 90% - "session_usage": 0.85, // 85% - "error_rate": 0.1, // 10% + "memory_usage": 0.8, // 80% + "memory_pressure": 0.7, // 70% + "cache_usage": 0.9, // 90% + "session_usage": 0.85, // 85% + "error_rate": 0.1, // 10% }, alerts: make([]ResourceAlert, 0), perfMetrics: perfMetrics, @@ -473,6 +514,25 @@ func (rm *ResourceMonitor) SetSessionLimit(count int64) { rm.maxSessions = count } +// UpdateCacheSize updates the size of a specific cache +func (rm *ResourceMonitor) UpdateCacheSize(cacheName string, size int64) { + rm.cacheMutex.Lock() + defer rm.cacheMutex.Unlock() + rm.cacheSizes[cacheName] = size +} + +// GetCacheSizes returns current cache sizes +func (rm *ResourceMonitor) GetCacheSizes() map[string]int64 { + rm.cacheMutex.RLock() + defer rm.cacheMutex.RUnlock() + + sizes := make(map[string]int64) + for name, size := range rm.cacheSizes { + sizes[name] = size + } + return sizes +} + // startMonitoring starts the background monitoring routine func (rm *ResourceMonitor) startMonitoring() { ticker := time.NewTicker(10 * time.Second) @@ -502,6 +562,21 @@ func (rm *ResourceMonitor) checkResourceUsage() { } } + // Check memory pressure + if memPressure, ok := metrics["memory_pressure"].(int64); ok { + pressureRatio := float64(memPressure) / 100.0 // Convert to 0-1 scale + if pressureRatio > rm.alertThresholds["memory_pressure"] { + rm.addAlert(ResourceAlert{ + Type: "memory_pressure", + Message: "Memory pressure exceeds threshold", + Threshold: rm.alertThresholds["memory_pressure"], + CurrentValue: pressureRatio, + Timestamp: time.Now(), + Severity: rm.getSeverity(pressureRatio, rm.alertThresholds["memory_pressure"]), + }) + } + } + // Check cache usage if cacheSize, ok := metrics["cache_size"].(int64); ok { cacheUsageRatio := float64(cacheSize) / float64(rm.maxCacheSize) @@ -592,6 +667,7 @@ func (rm *ResourceMonitor) GetAlerts() []ResourceAlert { // GetResourceStatus returns current resource status func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} { metrics := rm.perfMetrics.GetMetrics() + cacheSizes := rm.GetCacheSizes() status := map[string]interface{}{ "limits": map[string]interface{}{ @@ -599,8 +675,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} { "max_cache_size": rm.maxCacheSize, "max_sessions": rm.maxSessions, }, - "thresholds": rm.alertThresholds, - "current": metrics, + "thresholds": rm.alertThresholds, + "current": metrics, + "cache_sizes": cacheSizes, // Add expected keys for tests "memory_limit": uint64(rm.maxMemoryBytes), "cache_limit": int(rm.maxCacheSize), @@ -611,6 +688,9 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} { if memUsage, ok := metrics["memory_usage_bytes"].(int64); ok { status["memory_usage_ratio"] = float64(memUsage) / float64(rm.maxMemoryBytes) } + if memPressure, ok := metrics["memory_pressure"].(int64); ok { + status["memory_pressure_ratio"] = float64(memPressure) / 100.0 + } if cacheSize, ok := metrics["cache_size"].(int64); ok { status["cache_usage_ratio"] = float64(cacheSize) / float64(rm.maxCacheSize) } @@ -618,5 +698,12 @@ func (rm *ResourceMonitor) GetResourceStatus() map[string]interface{} { status["session_usage_ratio"] = float64(activeSessions) / float64(rm.maxSessions) } + // Calculate total cache size across all caches + var totalCacheSize int64 + for _, size := range cacheSizes { + totalCacheSize += size + } + status["total_cache_size"] = totalCacheSize + return status } diff --git a/session.go b/session.go index e94948a..71a92d8 100644 --- a/session.go +++ b/session.go @@ -190,13 +190,18 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (* // Initialize session pool. sm.sessionPool.New = func() interface{} { // Initialize SessionData with necessary fields and the mutex. - return &SessionData{ + sd := &SessionData{ manager: sm, accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), - refreshMutex: sync.Mutex{}, // Initialize the mutex - dirty: false, // Initialize dirty flag + refreshMutex: sync.Mutex{}, // Initialize the mutex + sessionMutex: sync.RWMutex{}, // Initialize the session mutex + dirty: false, // Initialize dirty flag + inUse: false, // Initialize in-use flag } + // Ensure the object is properly reset when created + sd.Reset() + return sd } return sm, nil @@ -482,6 +487,8 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error { // STABILITY FIX: Mark as not in use and return session to pool, regardless of error. // This ensures the session is always returned to the pool, preventing memory leaks. sd.inUse = false + // Reset the session data before returning to pool to prevent data leakage + sd.Reset() sd.manager.sessionPool.Put(sd) // Return the error from Save, if any @@ -602,6 +609,52 @@ func (sd *SessionData) SetAuthenticated(value bool) error { return nil } +// Reset clears all session data and prepares the SessionData object for reuse. +// This method is called when returning objects to the pool to prevent data leakage +// between different users/sessions. +func (sd *SessionData) Reset() { + sd.sessionMutex.Lock() + defer sd.sessionMutex.Unlock() + + // Clear all session values if sessions exist + if sd.mainSession != nil { + for k := range sd.mainSession.Values { + delete(sd.mainSession.Values, k) + } + sd.mainSession.ID = "" + sd.mainSession.IsNew = true + } + + if sd.accessSession != nil { + for k := range sd.accessSession.Values { + delete(sd.accessSession.Values, k) + } + sd.accessSession.ID = "" + sd.accessSession.IsNew = true + } + + if sd.refreshSession != nil { + for k := range sd.refreshSession.Values { + delete(sd.refreshSession.Values, k) + } + sd.refreshSession.ID = "" + sd.refreshSession.IsNew = true + } + + // Clear chunk maps + for k := range sd.accessTokenChunks { + delete(sd.accessTokenChunks, k) + } + for k := range sd.refreshTokenChunks { + delete(sd.refreshTokenChunks, k) + } + + // Reset state flags + sd.dirty = false + sd.inUse = false + sd.request = nil +} + // ReturnToPool explicitly returns this SessionData object to the pool. // This should be called when you're done with a SessionData in any error path // where Clear() is not called, to prevent memory leaks. @@ -609,8 +662,8 @@ func (sd *SessionData) ReturnToPool() { if sd != nil && sd.manager != nil { // STABILITY FIX: Only return to pool if not currently in use if !sd.inUse { - // Clear request reference to avoid memory leaks - sd.request = nil + // Reset the session data before returning to pool + sd.Reset() sd.manager.sessionPool.Put(sd) } }