Memory leak fixes and optimisations.

This commit is contained in:
2025-08-28 09:59:47 +01:00
parent c878784f1e
commit e2e2be53c1
6 changed files with 342 additions and 43 deletions
+8 -2
View File
@@ -140,10 +140,16 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
if err != nil {
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
}
defer resp.Body.Close()
defer func() {
// Always drain the body before closing to ensure connection can be reused
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
// Limit body read to prevent memory issues
limitReader := io.LimitReader(resp.Body, 1024*10) // 10KB limit
bodyBytes, _ := io.ReadAll(limitReader)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
+16 -29
View File
@@ -10,6 +10,7 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net/http"
"sync"
@@ -43,8 +44,6 @@ type JWKSet struct {
// network requests. The cache supports expiration and automatic
// refresh when keys expire.
type JWKCache struct {
expiresAt time.Time
jwks *JWKSet
internalCache *Cache
CacheLifetime time.Duration
maxSize int
@@ -88,30 +87,22 @@ func NewJWKCache() *JWKCache {
}
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// First check if we already have cached JWKS for this URL
// Use only the internalCache for storage to avoid double storage
if c.internalCache != nil {
if cachedJwks, found := c.internalCache.Get(jwksURL); found {
return cachedJwks.(*JWKSet), nil
}
}
// STABILITY FIX: Fix race condition in double-checked locking
// First read check with read lock
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
jwks := c.jwks // Copy reference while holding read lock
c.mutex.RUnlock()
return jwks, nil
}
c.mutex.RUnlock()
// Acquire write lock for potential update
c.mutex.Lock()
defer c.mutex.Unlock()
// Second check after acquiring write lock (double-checked locking)
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
// Double-check after acquiring write lock
if c.internalCache != nil {
if cachedJwks, found := c.internalCache.Get(jwksURL); found {
return cachedJwks.(*JWKSet), nil
}
}
// Fetch new JWKS
@@ -125,15 +116,12 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
return nil, fmt.Errorf("JWKS response contains no keys")
}
// Update cache atomically
c.jwks = jwks
// Store in the internalCache only (avoid double storage)
lifetime := c.CacheLifetime
if lifetime == 0 {
lifetime = 1 * time.Hour
}
c.expiresAt = time.Now().Add(lifetime)
// Also store in the internalCache
if c.internalCache != nil {
c.internalCache.Set(jwksURL, jwks, lifetime)
}
@@ -144,15 +132,10 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
// Cleanup removes the cached JWKS if it has expired.
// This is intended to be called periodically to ensure stale JWKS data is cleared.
// Cleanup removes expired entries from the cache.
// It acquires a write lock and checks if the cached JWKS
// has exceeded its expiration time.
// It delegates to the internal cache's cleanup method.
func (c *JWKCache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
if c.jwks != nil && now.After(c.expiresAt) {
c.jwks = nil
if c.internalCache != nil {
c.internalCache.Cleanup()
}
}
@@ -194,7 +177,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
}
defer resp.Body.Close()
defer func() {
// Always drain the body before closing to ensure connection can be reused
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode)
+2
View File
@@ -42,6 +42,8 @@ func cleanupReplayCache() {
if replayCache != nil {
replayCache.Close()
replayCache = nil
// Reset the once to allow re-initialization
replayCacheOnce = sync.Once{}
}
}
+55 -12
View File
@@ -104,6 +104,7 @@ const (
var (
globalCacheManager *CacheManager
cacheManagerOnce sync.Once
cacheManagerMutex sync.RWMutex
)
// CacheManager provides shared cache instances across middleware instances
@@ -198,6 +199,23 @@ func (cm *CacheManager) Close() error {
return nil
}
// CleanupGlobalCacheManager tears down the global cache manager singleton.
// It is intended for application shutdown so the shared caches do not
// outlive the middleware. Calling it when no manager is installed (never
// created, or already cleaned up) is a no-op that returns nil.
//
// It returns whatever error the manager's Close method reports.
//
// NOTE(review): the Once is reset while holding cacheManagerMutex; this is
// only race-free if the getter that calls cacheManagerOnce.Do synchronizes
// on the same mutex — confirm against GetGlobalCacheManager.
func CleanupGlobalCacheManager() error {
	cacheManagerMutex.Lock()
	defer cacheManagerMutex.Unlock()

	if globalCacheManager == nil {
		// Nothing installed, so there is nothing to close or reset.
		return nil
	}

	closeErr := globalCacheManager.Close()
	globalCacheManager = nil
	// A fresh Once lets the next initialization build a new manager.
	cacheManagerOnce = sync.Once{}
	return closeErr
}
// TokenVerifier defines the contract for token verification implementations.
// Implementations should validate token format, signature, and claims.
type TokenVerifier interface {
@@ -366,8 +384,10 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return fmt.Errorf("token too short to be valid JWT")
}
if blacklisted, exists := t.tokenBlacklist.Get(token); exists && blacklisted != nil {
return fmt.Errorf("token is blacklisted (raw string) in cache")
if t.tokenBlacklist != nil {
if blacklisted, exists := t.tokenBlacklist.Get(token); exists && blacklisted != nil {
return fmt.Errorf("token is blacklisted (raw string) in cache")
}
}
// Parse JWT to extract JTI for blacklist checking before cache lookup
@@ -391,8 +411,10 @@ func (t *TraefikOidc) VerifyToken(token string) error {
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
if blacklisted, exists := t.tokenBlacklist.Get(jti); exists && blacklisted != nil {
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
if t.tokenBlacklist != nil {
if blacklisted, exists := t.tokenBlacklist.Get(jti); exists && blacklisted != nil {
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
}
}
}
}
@@ -445,8 +467,12 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
// Always blacklist the JTI in the tokenBlacklist for replay detection
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
t.logger.Debugf("Added JTI %s to blacklist cache", jti)
if t.tokenBlacklist != nil {
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
t.logger.Debugf("Added JTI %s to blacklist cache", jti)
} else {
t.logger.Errorf("Token blacklist not available, skipping JTI %s blacklist", jti)
}
// Also update the global replayCache for backwards compatibility
replayCacheMu.Lock()
@@ -1069,6 +1095,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Attempt to get a new session to store CSRF etc.
session, _ = t.sessionManager.GetSession(req) // Ignore error here, proceed with new session
if session != nil {
// Ensure session is returned to pool when done
defer session.returnToPoolSafely()
// Pass rw to ensure expiring cookies are sent if possible
if clearErr := session.Clear(req, rw); clearErr != nil {
t.logger.Errorf("Error clearing potentially corrupted session: %v", clearErr)
@@ -1086,6 +1114,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Ensure session is returned to pool when done
defer session.returnToPoolSafely()
// --- URL Handling (Callback, Logout) ---
scheme := t.determineScheme(req)
host := t.determineHost(req)
@@ -1439,6 +1470,8 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
http.Error(rw, "Session error during callback", http.StatusInternalServerError)
return
}
// Ensure session is returned to pool when done
defer session.returnToPoolSafely()
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
@@ -2079,16 +2112,20 @@ func (t *TraefikOidc) RevokeToken(token string) {
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
// Add JTI to blacklist as well
expiry := time.Now().Add(24 * time.Hour)
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token JTI %s (added to blacklist)", jti)
if t.tokenBlacklist != nil {
t.tokenBlacklist.Set(jti, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token JTI %s (added to blacklist)", jti)
}
}
}
// Add raw token to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
// Use Set with a duration. Value 'true' is arbitrary, we only care about existence.
t.tokenBlacklist.Set(token, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token (added to blacklist)")
if t.tokenBlacklist != nil {
t.tokenBlacklist.Set(token, true, time.Until(expiry))
t.logger.Debugf("Locally revoked token (added to blacklist)")
}
}
// RevokeTokenWithProvider attempts to revoke a token directly with the OIDC provider
@@ -2141,11 +2178,17 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
if err != nil {
return fmt.Errorf("failed to send token revocation request: %w", err)
}
defer resp.Body.Close()
defer func() {
// Always drain the body before closing to ensure connection can be reused
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
}()
// Check the response
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
// Limit body read to prevent memory issues
limitReader := io.LimitReader(resp.Body, 1024*10) // 10KB limit
body, _ := io.ReadAll(limitReader)
// Log the failure details
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
+245
View File
@@ -0,0 +1,245 @@
package traefikoidc
import (
"context"
"net/http"
"net/http/httptest"
"runtime"
"sync"
"testing"
"time"
)
// TestMemoryLeakFixes exercises the cleanup paths touched by the memory leak
// fixes: cache cleanup goroutines, the global cache manager singleton,
// session pooling, HTTP client usage and middleware shutdown.
//
// The goroutine-count subtests compare runtime.NumGoroutine snapshots with a
// small tolerance and sleep to let background goroutines start and stop, so
// they are heuristic checks rather than exact leak detectors.
func TestMemoryLeakFixes(t *testing.T) {
	t.Run("Cache cleanup stops properly", func(t *testing.T) {
		// Track goroutine count before starting
		initialGoroutines := runtime.NumGoroutine()

		// Create multiple caches
		caches := make([]*Cache, 10)
		for i := range caches {
			caches[i] = NewCache()
			caches[i].Set("key", "value", time.Hour)
		}

		// Wait for goroutines to start
		time.Sleep(100 * time.Millisecond)

		// Check that goroutines were created
		afterCreateGoroutines := runtime.NumGoroutine()
		if afterCreateGoroutines <= initialGoroutines {
			t.Error("Expected goroutines to be created for cache cleanup")
		}

		// Close all caches
		for _, cache := range caches {
			cache.Close()
		}

		// Wait for goroutines to stop
		time.Sleep(200 * time.Millisecond)

		// Check that goroutines were cleaned up
		finalGoroutines := runtime.NumGoroutine()
		if finalGoroutines > initialGoroutines+2 { // Allow some tolerance
			t.Errorf("Goroutine leak detected: initial=%d, final=%d", initialGoroutines, finalGoroutines)
		}
	})

	t.Run("Global cache manager cleanup", func(t *testing.T) {
		// Get the global cache manager
		cm := GetGlobalCacheManager()
		if cm == nil {
			t.Fatal("Failed to get global cache manager")
		}

		// Use the caches
		cm.GetSharedTokenBlacklist().Set("key", "value", time.Hour)
		cm.GetSharedTokenCache().Set("key", map[string]interface{}{"test": "data"}, time.Hour)

		// Clean up the global cache manager
		if err := CleanupGlobalCacheManager(); err != nil {
			t.Errorf("Failed to cleanup global cache manager: %v", err)
		}

		// Verify it can be re-initialized
		if cm2 := GetGlobalCacheManager(); cm2 == nil {
			t.Fatal("Failed to re-initialize global cache manager")
		}
	})

	t.Run("Session pool returns properly", func(t *testing.T) {
		logger := NewLogger("debug")
		sm, err := NewSessionManager("test-encryption-key-that-is-long-enough-32bytes", false, logger)
		if err != nil {
			t.Fatal(err)
		}

		// Create multiple sessions concurrently
		var wg sync.WaitGroup
		for i := 0; i < 100; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()

				req := httptest.NewRequest("GET", "/", nil)
				session, err := sm.GetSession(req)
				if err != nil {
					// Report instead of returning silently; otherwise every
					// goroutine could fail and the subtest would still pass.
					// Errorf (unlike Fatal) is safe from a non-test goroutine.
					t.Errorf("Failed to get session in worker: %v", err)
					return
				}

				// Simulate some work
				session.SetAccessToken("dummy-access-token")

				// Properly return to pool
				session.returnToPoolSafely()
			}()
		}
		wg.Wait()

		// Check that sessions can still be obtained
		req := httptest.NewRequest("GET", "/", nil)
		session, err := sm.GetSession(req)
		if err != nil {
			t.Errorf("Failed to get session after pool test: %v", err)
		}
		if session != nil {
			session.returnToPoolSafely()
		}
	})

	t.Run("HTTP response bodies are drained", func(t *testing.T) {
		// Create a test server
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Return a response with body
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`{"test": "data"}`))
		}))
		defer server.Close()

		// Create HTTP client with our fixes
		client := createDefaultHTTPClient()

		// Make multiple requests
		for i := 0; i < 10; i++ {
			resp, err := client.Get(server.URL)
			if err != nil {
				t.Fatal(err)
			}
			if resp.StatusCode != http.StatusOK {
				t.Errorf("request %d: got status %d, want %d", i, resp.StatusCode, http.StatusOK)
			}
			// The body is only closed here, not read. This subtest talks to
			// the server directly and does not go through the production
			// call sites that drain before closing.
			if err := resp.Body.Close(); err != nil {
				t.Errorf("request %d: closing response body: %v", i, err)
			}
		}

		// Release any idle connections held by the transport.
		if transport, ok := client.Transport.(*http.Transport); ok {
			transport.CloseIdleConnections()
			t.Log("HTTP connections properly managed")
		}
	})

	t.Run("Middleware cleanup releases all resources", func(t *testing.T) {
		// Track initial goroutines
		initialGoroutines := runtime.NumGoroutine()

		// Create a middleware instance
		config := CreateConfig()
		config.ProviderURL = "https://example.com"
		config.ClientID = "test-client"
		config.ClientSecret = "test-secret"
		config.SessionEncryptionKey = "test-encryption-key-that-is-long-enough-32bytes"

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})

		handler, err := New(ctx, next, config, "test-middleware")
		if err != nil {
			t.Fatal(err)
		}

		// Cast to TraefikOidc to access Close method. Fail loudly when the
		// type differs; skipping silently would leave cleanup unverified.
		middleware, ok := handler.(*TraefikOidc)
		if !ok {
			t.Fatalf("New returned %T, want *TraefikOidc", handler)
		}

		// Wait for initialization
		time.Sleep(100 * time.Millisecond)

		// Close the middleware
		if err := middleware.Close(); err != nil {
			t.Errorf("Failed to close middleware: %v", err)
		}

		// Wait for cleanup
		time.Sleep(500 * time.Millisecond)

		// Check goroutines
		finalGoroutines := runtime.NumGoroutine()
		if finalGoroutines > initialGoroutines+5 { // Allow some tolerance
			t.Errorf("Possible goroutine leak: initial=%d, final=%d", initialGoroutines, finalGoroutines)
		}
	})
}
// TestJWKCacheNoDoubleStorage checks that JWKCache serves repeated GetJWKS
// calls for the same URL from its internal cache: the JWKS endpoint must be
// fetched once, and every call must return the single key it published.
func TestJWKCacheNoDoubleStorage(t *testing.T) {
	cache := NewJWKCache()
	defer cache.Close()

	// The handler runs on the test server's goroutine, so the hit counter is
	// guarded by a mutex.
	var (
		hitsMu sync.Mutex
		hits   int
	)

	// Create a test server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hitsMu.Lock()
		hits++
		hitsMu.Unlock()

		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"keys": [{"kty": "RSA", "kid": "test-key", "use": "sig", "n": "test", "e": "AQAB"}]}`))
	}))
	defer server.Close()

	ctx := context.Background()
	client := &http.Client{Timeout: 5 * time.Second}

	// Get JWKS multiple times
	for i := 0; i < 3; i++ {
		jwks, err := cache.GetJWKS(ctx, server.URL, client)
		if err != nil {
			t.Fatal(err)
		}
		if jwks == nil || len(jwks.Keys) != 1 {
			t.Error("Expected JWKS with one key")
		}
	}

	// Only the first call should have reached the server; later calls are
	// expected to be answered from the cache entry stored for this URL.
	hitsMu.Lock()
	gotHits := hits
	hitsMu.Unlock()
	if gotHits != 1 {
		t.Errorf("JWKS endpoint fetched %d times, want 1 (subsequent calls should hit the cache)", gotHits)
	}

	// The cache should only use internalCache for storage
	if cache.internalCache == nil {
		t.Error("Internal cache should be initialized")
	}

	// Run cleanup
	cache.Cleanup()
}
// TestGlobalSingletonCleanup verifies that the global memory pool manager can
// be obtained, used, cleaned up, and then obtained again.
func TestGlobalSingletonCleanup(t *testing.T) {
	manager := GetGlobalMemoryPools()
	if manager == nil {
		t.Fatal("Failed to get global memory pools")
	}

	// Borrow a response buffer and hand it straight back to the pool.
	manager.PutHTTPResponseBuffer(manager.GetHTTPResponseBuffer())

	// Drop the singleton.
	CleanupGlobalMemoryPools()

	// A subsequent lookup must yield a usable manager again.
	if reinitialized := GetGlobalMemoryPools(); reinitialized == nil {
		t.Fatal("Failed to re-initialize global memory pools")
	}
}
+16
View File
@@ -217,6 +217,7 @@ func (p *TokenCompressionPool) PutStringBuilder(sb *strings.Builder) {
// Global memory pool manager instance
var globalMemoryPools *MemoryPoolManager
var memoryPoolOnce sync.Once
var memoryPoolMutex sync.RWMutex
// GetGlobalMemoryPools returns the singleton memory pool manager
func GetGlobalMemoryPools() *MemoryPoolManager {
@@ -225,3 +226,18 @@ func GetGlobalMemoryPools() *MemoryPoolManager {
})
return globalMemoryPools
}
// CleanupGlobalMemoryPools drops the global memory pool manager singleton.
// It is intended for application shutdown. Calling it when no manager is
// installed (never created, or already cleaned up) does nothing.
//
// NOTE(review): the Once is reset while holding memoryPoolMutex; this is only
// race-free if GetGlobalMemoryPools synchronizes on the same mutex — confirm
// against the getter.
func CleanupGlobalMemoryPools() {
	memoryPoolMutex.Lock()
	defer memoryPoolMutex.Unlock()

	if globalMemoryPools == nil {
		return
	}

	// Remove the package-level reference to the manager and its pools.
	globalMemoryPools = nil
	// A fresh Once lets the next initialization build a new manager.
	memoryPoolOnce = sync.Once{}
}