Compare commits

...

4 Commits

16 changed files with 2950 additions and 62 deletions
+24 -2
View File
@@ -1,6 +1,9 @@
package traefikoidc
import "time"
import (
"sync"
"time"
)
// BackgroundTask represents a managed recurring task that runs in the background.
// It provides a clean interface for starting and stopping periodic operations
@@ -11,6 +14,7 @@ type BackgroundTask struct {
logger *Logger
name string
interval time.Duration
wg *sync.WaitGroup
}
// NewBackgroundTask creates a new background task with the specified parameters.
@@ -20,35 +24,53 @@ type BackgroundTask struct {
// - interval: Duration between task executions.
// - taskFunc: The function to execute periodically.
// - logger: Logger instance for task lifecycle events.
// - wg: Optional WaitGroup for synchronizing goroutine completion.
//
// Returns:
// - A configured BackgroundTask ready to be started.
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger) *BackgroundTask {
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger, wg ...*sync.WaitGroup) *BackgroundTask {
var waitGroup *sync.WaitGroup
if len(wg) > 0 {
waitGroup = wg[0]
}
return &BackgroundTask{
name: name,
interval: interval,
stopChan: make(chan struct{}),
taskFunc: taskFunc,
logger: logger,
wg: waitGroup,
}
}
// Start begins the background task execution in a separate goroutine.
// The task runs immediately upon start and then at the specified interval.
func (bt *BackgroundTask) Start() {
if bt.wg != nil {
bt.wg.Add(1)
}
go bt.run()
}
// Stop gracefully terminates the background task by closing the stop channel.
// If a WaitGroup was provided, it waits for the goroutine to complete.
// This method is safe to call multiple times.
func (bt *BackgroundTask) Stop() {
close(bt.stopChan)
if bt.wg != nil {
bt.wg.Wait()
}
}
// run is the main execution loop for the background task.
// It executes the task function immediately and then at regular intervals
// until the stop signal is received.
func (bt *BackgroundTask) run() {
defer func() {
if bt.wg != nil {
bt.wg.Done()
}
}()
ticker := time.NewTicker(bt.interval)
defer ticker.Stop()
+12 -6
View File
@@ -53,7 +53,7 @@ func NewCacheWithLogger(logger *Logger) *Cache {
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 5 * time.Minute,
autoCleanupInterval: 15 * time.Minute, // Increased from 5 minutes to reduce overhead
logger: logger,
}
c.startAutoCleanup()
@@ -155,16 +155,21 @@ func (c *Cache) Cleanup() {
}
// evictOldest removes the least recently used (oldest) item from the cache.
// It first attempts to find and remove an expired item from the front of the LRU list.
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
// It first attempts to find and remove expired items, checking up to 5 items
// from the front of the LRU list for efficiency. If no expired items are found,
// it removes the absolute oldest item (front of the list).
// This method is called internally by Set when the cache reaches its maximum size.
// Note: This function assumes the write lock is already held.
func (c *Cache) evictOldest() {
now := time.Now()
elem := c.order.Front()
// First try to find an expired item from the front
for elem != nil {
// Check up to 5 items from the front for expired entries
// This limits the search overhead while still finding expired items efficiently
const maxExpiredCheck = 5
checked := 0
for elem != nil && checked < maxExpiredCheck {
entry := elem.Value.(lruEntry)
if item, exists := c.items[entry.key]; exists {
if now.After(item.ExpiresAt) {
@@ -173,9 +178,10 @@ func (c *Cache) evictOldest() {
}
}
elem = elem.Next()
checked++
}
// If no expired items found, remove the oldest item
// If no expired items found in the first few entries, remove the oldest item
if elem = c.order.Front(); elem != nil {
entry := elem.Value.(lruEntry)
c.removeItem(entry.key)
+336
View File
@@ -0,0 +1,336 @@
package traefikoidc
import (
"container/list"
"fmt"
"runtime"
"sync"
"testing"
"time"
)
// TestCacheMemoryLeaks tests various cache scenarios for memory leaks
func TestCacheMemoryLeaks(t *testing.T) {
t.Run("Cache doesn't release expired items memory", func(t *testing.T) {
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
cache := NewCache()
defer cache.Close()
// Add many large items with short expiration
largeData := make([]byte, 1024*1024) // 1MB
for i := 0; i < 100; i++ {
key := fmt.Sprintf("key-%d", i)
// Items expire in 1 second
cache.Set(key, largeData, 1*time.Second)
}
// Wait for items to expire
time.Sleep(2 * time.Second)
// Force cleanup
cache.Cleanup()
// Check memory after cleanup
runtime.GC()
runtime.ReadMemStats(&m)
afterCleanupAlloc := m.Alloc
allocIncrease := float64(afterCleanupAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Memory after adding and expiring 100MB of data: %.2f MB", allocIncrease)
// Memory should be mostly released after cleanup
if allocIncrease > 10.0 {
t.Errorf("Cache retains too much memory after cleanup: %.2f MB", allocIncrease)
}
})
t.Run("Token blacklist unbounded growth", func(t *testing.T) {
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
blacklist := NewCache()
blacklist.SetMaxSize(1000) // Limit size
defer blacklist.Close()
// Simulate continuous token blacklisting
for i := 0; i < 10000; i++ {
token := fmt.Sprintf("token-%d", i)
// All tokens expire in 24 hours (typical blacklist duration)
blacklist.Set(token, true, 24*time.Hour)
}
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Memory after adding 10000 blacklisted tokens (max 1000): %.2f MB", allocIncrease)
// Should respect max size limit
if len(blacklist.items) > 1000 {
t.Errorf("Blacklist exceeded max size: %d items", len(blacklist.items))
}
// Memory should be bounded
if allocIncrease > 5.0 {
t.Errorf("Blacklist uses too much memory: %.2f MB for max 1000 items", allocIncrease)
}
})
t.Run("Replay cache with high JTI volume", func(t *testing.T) {
initReplayCache()
defer cleanupReplayCache()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Simulate high volume of JTIs
for i := 0; i < 20000; i++ {
jti := fmt.Sprintf("jti-%d", i)
replayCacheMu.Lock()
if replayCache != nil {
// JTIs expire after token expiry (typically 1 hour)
replayCache.Set(jti, true, 1*time.Hour)
}
replayCacheMu.Unlock()
}
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Memory after adding 20000 JTIs (max 10000): %.2f MB", allocIncrease)
// Check size limit is enforced
replayCacheMu.RLock()
cacheSize := 0
if replayCache != nil {
cacheSize = len(replayCache.items)
}
replayCacheMu.RUnlock()
if cacheSize > 10000 {
t.Errorf("Replay cache exceeded max size: %d items", cacheSize)
}
// Memory should be bounded
if allocIncrease > 10.0 {
t.Errorf("Replay cache uses too much memory: %.2f MB for max 10000 items", allocIncrease)
}
})
t.Run("Cache cleanup interval effectiveness", func(t *testing.T) {
// Create a cache with custom settings - don't use NewCache to avoid default cleanup
cache := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 200 * time.Millisecond, // Fast cleanup for test
logger: newNoOpLogger(),
}
// Start cleanup with our custom interval
cache.startAutoCleanup()
defer cache.Close()
// Add expired items
for i := 0; i < 1000; i++ {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, "data", 50*time.Millisecond) // Very short expiry
}
// Wait for items to expire and cleanup to run (at least 2 cleanup cycles)
time.Sleep(600 * time.Millisecond)
// Manually trigger cleanup to ensure it runs
cache.Cleanup()
// Check that expired items are removed
cache.mutex.RLock()
remainingItems := len(cache.items)
cache.mutex.RUnlock()
t.Logf("Remaining items after auto cleanup: %d", remainingItems)
if remainingItems > 100 {
t.Errorf("Auto cleanup not effective: %d items remain", remainingItems)
}
})
t.Run("Concurrent cache operations memory stability", func(t *testing.T) {
cache := NewCache()
defer cache.Close()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
var wg sync.WaitGroup
stop := make(chan struct{})
// Writers continuously add items
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 1000; j++ {
select {
case <-stop:
return
default:
key := fmt.Sprintf("writer-%d-%d", id, j)
cache.Set(key, "data", 1*time.Second)
time.Sleep(1 * time.Millisecond)
}
}
}(i)
}
// Readers continuously read items
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 1000; j++ {
select {
case <-stop:
return
default:
key := fmt.Sprintf("writer-%d-%d", id%5, j)
cache.Get(key)
time.Sleep(1 * time.Millisecond)
}
}
}(i)
}
// Let it run for a bit
time.Sleep(5 * time.Second)
close(stop)
wg.Wait()
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
// Handle potential underflow
var allocIncrease float64
if finalAlloc > baselineAlloc {
allocIncrease = float64(finalAlloc-baselineAlloc) / 1024 / 1024
} else {
allocIncrease = -float64(baselineAlloc-finalAlloc) / 1024 / 1024
}
t.Logf("Memory increase under concurrent load: %.2f MB", allocIncrease)
if allocIncrease > 5.0 {
t.Errorf("Memory leak under concurrent operations: %.2f MB", allocIncrease)
}
})
t.Run("LRU eviction memory release", func(t *testing.T) {
cache := NewCache()
cache.SetMaxSize(100) // Small cache
defer cache.Close()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Add many items to trigger eviction
for i := 0; i < 1000; i++ {
key := fmt.Sprintf("key-%d", i)
data := make([]byte, 10240) // 10KB per item
cache.Set(key, data, 1*time.Hour)
}
runtime.GC()
runtime.ReadMemStats(&m)
afterEvictionAlloc := m.Alloc
allocIncrease := float64(afterEvictionAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Memory after LRU eviction (1000 items, max 100): %.2f MB", allocIncrease)
// Should only keep 100 items worth of memory
if allocIncrease > 2.0 { // 100 * 10KB = ~1MB
t.Errorf("LRU eviction doesn't release memory properly: %.2f MB", allocIncrease)
}
// Verify cache size
if len(cache.items) > 100 {
t.Errorf("Cache size exceeded limit: %d items", len(cache.items))
}
})
t.Run("Token cache with claims memory", func(t *testing.T) {
tokenCache := NewTokenCache()
defer tokenCache.Close()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Add tokens with large claims
for i := 0; i < 1000; i++ {
token := fmt.Sprintf("token-%d", i)
claims := map[string]interface{}{
"sub": fmt.Sprintf("user-%d", i),
"email": fmt.Sprintf("user%d@example.com", i),
"groups": make([]string, 100), // Large groups list
"data": make([]byte, 1024), // Extra data
}
tokenCache.Set(token, claims, 1*time.Hour)
}
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Memory after adding 1000 tokens with large claims: %.2f MB", allocIncrease)
// Check if memory is reasonable
if allocIncrease > 20.0 {
t.Errorf("Token cache uses excessive memory: %.2f MB", allocIncrease)
}
})
}
// TestCacheEvictionBug tests the inefficient eviction in evictOldest
func TestCacheEvictionBug(t *testing.T) {
t.Run("evictOldest scans entire list", func(t *testing.T) {
cache := NewCache()
cache.SetMaxSize(100)
defer cache.Close()
// Fill cache with non-expired items
for i := 0; i < 100; i++ {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, "data", 1*time.Hour) // Long expiry
}
// Try to add one more item to trigger eviction
start := time.Now()
cache.Set("trigger", "data", 1*time.Hour)
elapsed := time.Since(start)
t.Logf("Time to evict and add one item: %v", elapsed)
// Should be fast even with full cache
if elapsed > 10*time.Millisecond {
t.Errorf("Eviction too slow, possibly scanning entire list: %v", elapsed)
}
})
}
+646
View File
@@ -0,0 +1,646 @@
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"runtime/debug"
"strings"
"sync"
"testing"
"time"
"golang.org/x/time/rate"
)
// TestBackgroundGoroutineLeaks tests that background goroutines don't leak memory
// even when no requests are made to protected resources
func TestBackgroundGoroutineLeaks(t *testing.T) {
t.Run("Idle middleware memory growth", func(t *testing.T) {
// Force GC to get clean baseline
runtime.GC()
runtime.GC()
time.Sleep(100 * time.Millisecond)
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
baselineGoroutines := runtime.NumGoroutine()
t.Logf("Baseline: Memory=%d KB, Goroutines=%d", baselineAlloc/1024, baselineGoroutines)
// Create middleware instance
config := CreateConfig()
config.ProviderURL = "https://example.com"
config.ClientID = "test-client"
config.ClientSecret = "test-secret"
config.SessionEncryptionKey = "test-encryption-key-that-is-long-enough-32bytes"
config.LogLevel = "error" // Reduce noise
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler, err := New(ctx, next, config, "test-middleware")
if err != nil {
t.Fatal(err)
}
// Cast to TraefikOidc to access internals
middleware, ok := handler.(*TraefikOidc)
if !ok {
t.Fatal("Failed to cast to TraefikOidc")
}
// Let it sit idle for a while - simulating no requests
// During this time, background goroutines are running
t.Log("Letting middleware sit idle for 30 seconds...")
// Take measurements every 5 seconds
for i := 0; i < 6; i++ {
time.Sleep(5 * time.Second)
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
currentGoroutines := runtime.NumGoroutine()
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
goroutineIncrease := currentGoroutines - baselineGoroutines
t.Logf("After %d seconds: Memory increase=%.2f MB, Goroutine increase=%d",
(i+1)*5, allocIncrease, goroutineIncrease)
// Check for significant memory growth (more than 5MB)
if allocIncrease > 5.0 {
t.Errorf("Significant memory increase detected: %.2f MB after %d seconds of idle",
allocIncrease, (i+1)*5)
}
// Check for goroutine leaks (more than 10 extra goroutines)
if goroutineIncrease > 10 {
t.Errorf("Goroutine leak detected: %d extra goroutines after %d seconds",
goroutineIncrease, (i+1)*5)
}
}
// Clean up
err = middleware.Close()
if err != nil {
t.Errorf("Failed to close middleware: %v", err)
}
// Wait for cleanup
time.Sleep(500 * time.Millisecond)
// Final check
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
finalGoroutines := runtime.NumGoroutine()
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
finalGoroutineIncrease := finalGoroutines - baselineGoroutines
t.Logf("Final: Memory increase=%.2f MB, Goroutine increase=%d",
finalAllocIncrease, finalGoroutineIncrease)
if finalGoroutineIncrease > 2 {
t.Errorf("Goroutines not cleaned up properly: %d extra goroutines remain",
finalGoroutineIncrease)
}
})
}
// TestHTTPClientConnectionLeaks tests that HTTP clients don't leak connections
func TestHTTPClientConnectionLeaks(t *testing.T) {
t.Run("HTTP client connection accumulation", func(t *testing.T) {
// Create test server that simulates OIDC endpoints
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/jwks",
"userinfo_endpoint": "https://example.com/userinfo"
}`))
case "/jwks":
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"keys": []}`))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
// Monitor connection count
getActiveConnections := func(client *http.Client) int {
if transport, ok := client.Transport.(*http.Transport); ok {
// This is a simplified check - in reality we'd need to inspect
// the transport's connection pool
return transport.MaxIdleConnsPerHost
}
return 0
}
// Create multiple HTTP clients like the middleware does
clients := make([]*http.Client, 10)
for i := 0; i < 10; i++ {
clients[i] = createDefaultHTTPClient()
// Make requests
resp, err := clients[i].Get(server.URL + "/.well-known/openid-configuration")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
// Check connection settings
conns := getActiveConnections(clients[i])
if conns > 1 {
t.Logf("Client %d has %d max idle connections per host", i, conns)
}
}
// Let connections sit idle (now only 5s with our optimizations)
time.Sleep(6 * time.Second) // Slightly longer than new IdleConnTimeout (5s)
// Force cleanup
for _, client := range clients {
if transport, ok := client.Transport.(*http.Transport); ok {
transport.CloseIdleConnections()
}
}
t.Log("HTTP clients cleaned up successfully")
})
}
// TestCacheBackgroundTaskLeaks tests that cache background tasks don't leak
func TestCacheBackgroundTaskLeaks(t *testing.T) {
t.Run("Multiple cache instances with cleanup tasks", func(t *testing.T) {
initialGoroutines := runtime.NumGoroutine()
// Create many cache instances
caches := make([]*Cache, 50)
for i := 0; i < 50; i++ {
caches[i] = NewCache()
// Add some data
for j := 0; j < 100; j++ {
key := fmt.Sprintf("key-%d-%d", i, j)
caches[i].Set(key, "value", 5*time.Minute)
}
}
// Wait for all cleanup goroutines to start
time.Sleep(200 * time.Millisecond)
afterCreateGoroutines := runtime.NumGoroutine()
goroutineIncrease := afterCreateGoroutines - initialGoroutines
t.Logf("Created %d caches, goroutine increase: %d", len(caches), goroutineIncrease)
// Expected: one cleanup goroutine per cache
if goroutineIncrease < len(caches) {
t.Errorf("Expected at least %d goroutines, got %d increase",
len(caches), goroutineIncrease)
}
// Close all caches
for _, cache := range caches {
cache.Close()
}
// Wait for goroutines to stop
time.Sleep(500 * time.Millisecond)
finalGoroutines := runtime.NumGoroutine()
remainingGoroutines := finalGoroutines - initialGoroutines
if remainingGoroutines > 5 { // Allow small tolerance
t.Errorf("Cache cleanup goroutines not stopped properly: %d extra goroutines remain",
remainingGoroutines)
}
})
}
// TestGlobalSingletonMemoryGrowth tests that global singletons don't grow unbounded
func TestGlobalSingletonMemoryGrowth(t *testing.T) {
t.Run("Global cache manager memory growth", func(t *testing.T) {
// Clean up any existing global state
CleanupGlobalCacheManager()
CleanupGlobalMemoryPools()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Get global cache manager
wg := &sync.WaitGroup{}
cm := GetGlobalCacheManager(wg)
// Simulate continuous usage without cleanup
for i := 0; i < 1000; i++ {
// Add to token cache
cm.GetSharedTokenCache().Set(
fmt.Sprintf("token-%d", i),
map[string]interface{}{"claim": fmt.Sprintf("value-%d", i)},
5*time.Minute,
)
// Add to blacklist
cm.GetSharedTokenBlacklist().Set(
fmt.Sprintf("blacklist-%d", i),
true,
5*time.Minute,
)
// Every 100 iterations, check memory
if i%100 == 0 {
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("After %d items: Memory increase=%.2f MB", i, allocIncrease)
// The caches should have max size limits
// If memory grows more than 10MB, there's likely a leak
if allocIncrease > 10.0 {
t.Errorf("Excessive memory growth in global caches: %.2f MB after %d items",
allocIncrease, i)
break
}
}
}
// Cleanup
CleanupGlobalCacheManager()
CleanupGlobalMemoryPools()
wg.Wait()
// Force full cleanup of replay cache too
cleanupReplayCache()
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Final memory increase after cleanup: %.2f MB", finalAllocIncrease)
if finalAllocIncrease > 2.0 {
t.Errorf("Memory not properly released after cleanup: %.2f MB remains",
finalAllocIncrease)
}
})
}
// TestMetadataCacheRefreshLeak tests for memory leaks in metadata refresh
func TestMetadataCacheRefreshLeak(t *testing.T) {
t.Run("Metadata cache refresh memory leak", func(t *testing.T) {
wg := &sync.WaitGroup{}
cache := NewMetadataCacheWithLogger(wg, NewLogger("error"))
defer cache.Close()
// Mock HTTP client that returns metadata
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Return a large metadata response to amplify any leaks
w.Write([]byte(`{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/jwks",
"userinfo_endpoint": "https://example.com/userinfo",
"extra_field_1": "` + strings.Repeat("x", 1000) + `",
"extra_field_2": "` + strings.Repeat("y", 1000) + `"
}`))
}))
defer server.Close()
client := &http.Client{Timeout: 5 * time.Second}
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Simulate many metadata refreshes
for i := 0; i < 100; i++ {
_, err := cache.GetMetadata(server.URL, client, NewLogger("error"))
if err != nil {
t.Logf("Metadata fetch error (expected for test server): %v", err)
}
// Force cache expiry to trigger refresh
cache.mutex.Lock()
cache.expiresAt = time.Now().Add(-1 * time.Hour)
cache.mutex.Unlock()
if i%20 == 0 && i > 0 {
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("After %d refreshes: Memory increase=%.2f MB", i, allocIncrease)
// Metadata cache should only store one copy
if allocIncrease > 3.0 {
t.Errorf("Metadata cache leak detected: %.2f MB after %d refreshes",
allocIncrease, i)
break
}
}
}
cache.Close()
wg.Wait()
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Final memory after metadata cache closure: %.2f MB", finalAllocIncrease)
if finalAllocIncrease > 1.0 {
t.Errorf("Metadata cache not cleaned properly: %.2f MB remains", finalAllocIncrease)
}
})
}
// TestMemoryPoolLeak tests for leaks in memory pool management
func TestMemoryPoolLeak(t *testing.T) {
t.Run("Memory pool buffer leaks", func(t *testing.T) {
// Clean up any existing pools
CleanupGlobalMemoryPools()
pools := GetGlobalMemoryPools()
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
// Simulate heavy buffer usage
var buffers [][]byte
for i := 0; i < 1000; i++ {
buf := pools.GetHTTPResponseBuffer()
// Simulate using the buffer
copy(buf, []byte("test data"))
// 90% of the time, return the buffer
// 10% of the time, "forget" to return it (simulating a leak)
if i%10 != 0 {
pools.PutHTTPResponseBuffer(buf)
} else {
// Keep reference to simulate leak
buffers = append(buffers, buf)
}
if i%100 == 0 && i > 0 {
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
t.Logf("After %d buffer operations: Memory increase=%.2f MB, Leaked buffers=%d",
i, allocIncrease, len(buffers))
// With proper pooling, memory should be bounded
if allocIncrease > 5.0 {
t.Errorf("Memory pool leak detected: %.2f MB after %d operations",
allocIncrease, i)
break
}
}
}
// Return the "leaked" buffers
for _, buf := range buffers {
pools.PutHTTPResponseBuffer(buf)
}
CleanupGlobalMemoryPools()
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
t.Logf("Final memory after pool cleanup: %.2f MB", finalAllocIncrease)
if finalAllocIncrease > 1.0 {
t.Errorf("Memory pools not cleaned properly: %.2f MB remains", finalAllocIncrease)
}
})
}
// TestConcurrentMemoryLeaks tests for memory leaks under concurrent load
func TestConcurrentMemoryLeaks(t *testing.T) {
t.Run("Concurrent operations memory stability", func(t *testing.T) {
// Reset global state
resetGlobalState()
// Set lower GC percentage to trigger GC more frequently
debug.SetGCPercent(50)
defer debug.SetGCPercent(100)
runtime.GC()
var m runtime.MemStats
runtime.ReadMemStats(&m)
baselineAlloc := m.Alloc
baselineGoroutines := runtime.NumGoroutine()
// Create a mock OIDC server
var mockServerURL string
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/openid-configuration":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": mockServerURL,
"authorization_endpoint": mockServerURL + "/auth",
"token_endpoint": mockServerURL + "/token",
"userinfo_endpoint": mockServerURL + "/userinfo",
"jwks_uri": mockServerURL + "/keys",
})
case "/keys":
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"use": "sig",
"kid": "test-key",
"n": "test-modulus",
"e": "AQAB",
},
},
})
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer mockServer.Close()
mockServerURL = mockServer.URL
// Create middleware config with mock server
config := createTestConfig()
config.ProviderURL = mockServerURL
config.LogLevel = "error"
// Create middleware directly without using New() to avoid automatic metadata fetch
middleware := &TraefikOidc{
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }),
providerURL: config.ProviderURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
redirURLPath: "/callback",
scopes: []string{"openid", "email", "profile"},
logger: NewLogger(config.LogLevel),
excludedURLs: make(map[string]struct{}),
httpClient: &http.Client{Timeout: 5 * time.Second},
sessionManager: nil, // Will be created below
tokenCache: NewTokenCache(),
tokenBlacklist: NewCache(),
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
goroutineWG: &sync.WaitGroup{},
firstRequestMutex: sync.Mutex{},
firstRequestReceived: false,
}
// Create session manager
var err error
middleware.sessionManager, err = NewSessionManager(
config.SessionEncryptionKey,
config.ForceHTTPS,
config.CookieDomain,
middleware.logger,
)
if err != nil {
t.Fatal(err)
}
// Simulate concurrent operations
var wg sync.WaitGroup
stopChan := make(chan struct{})
// Worker that continuously performs operations
worker := func(id int) {
defer wg.Done()
for {
select {
case <-stopChan:
return
default:
// Simulate various operations
switch id % 4 {
case 0:
// Cache operations
cache := NewCache()
cache.Set(fmt.Sprintf("key-%d", id), "value", time.Minute)
cache.Get(fmt.Sprintf("key-%d", id))
cache.Close()
case 1:
// Session operations
req := httptest.NewRequest("GET", "/", nil)
session, _ := middleware.sessionManager.GetSession(req)
if session != nil {
session.SetAccessToken("token")
session.returnToPoolSafely()
}
case 2:
// Token cache operations
cm := GetGlobalCacheManager(nil)
if cm != nil {
tc := cm.GetSharedTokenCache()
tc.Set("test-token", map[string]interface{}{"test": "data"}, time.Minute)
tc.Get("test-token")
}
case 3:
// Memory pool operations
pools := GetGlobalMemoryPools()
buf := pools.GetHTTPResponseBuffer()
pools.PutHTTPResponseBuffer(buf)
}
// Small delay to prevent tight loop
time.Sleep(10 * time.Millisecond)
}
}
}
// Start workers
numWorkers := 20
wg.Add(numWorkers)
for i := 0; i < numWorkers; i++ {
go worker(i)
}
// Let it run and measure periodically
for i := 0; i < 5; i++ {
time.Sleep(5 * time.Second)
runtime.GC()
runtime.ReadMemStats(&m)
currentAlloc := m.Alloc
currentGoroutines := runtime.NumGoroutine()
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
goroutineIncrease := currentGoroutines - baselineGoroutines
t.Logf("After %d seconds of concurrent load: Memory increase=%.2f MB, Goroutine increase=%d",
(i+1)*5, allocIncrease, goroutineIncrease)
// Under sustained concurrent load, memory should stabilize
if i > 2 && allocIncrease > 20.0 {
t.Errorf("Memory leak under concurrent load: %.2f MB after %d seconds",
allocIncrease, (i+1)*5)
}
}
// Stop workers
close(stopChan)
wg.Wait()
// Clean up
middleware.Close()
// Final measurements
time.Sleep(500 * time.Millisecond)
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
finalGoroutines := runtime.NumGoroutine()
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
finalGoroutineIncrease := finalGoroutines - baselineGoroutines
t.Logf("Final after cleanup: Memory increase=%.2f MB, Goroutine increase=%d",
finalAllocIncrease, finalGoroutineIncrease)
if finalGoroutineIncrease > 5 {
t.Errorf("Goroutines not cleaned up after concurrent operations: %d extra remain",
finalGoroutineIncrease)
}
if finalAllocIncrease > 5.0 {
t.Errorf("Memory not released after concurrent operations: %.2f MB remains",
finalAllocIncrease)
}
})
}
+24 -5
View File
@@ -64,15 +64,37 @@ func (r *ProviderRegistry) ClearCache() {
// DetectProvider determines the most appropriate provider for a given issuer URL.
// It iterates through the registered providers and returns the first one that matches.
// Detection is based on URL patterns and other provider-specific criteria.
// Uses double-checked locking pattern to avoid race conditions while caching results.
func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
// First check: read lock for cache lookup
r.mu.RLock()
defer r.mu.RUnlock()
if provider, found := r.cache[issuerURL]; found {
r.mu.RUnlock()
return provider
}
r.mu.RUnlock()
// Check cache first for performance
// Cache miss - acquire write lock for detection and caching
r.mu.Lock()
defer r.mu.Unlock()
// Second check: another goroutine might have cached the result while we waited for write lock
if provider, found := r.cache[issuerURL]; found {
return provider
}
// Perform detection under write lock
detectedProvider := r.detectProviderUnsafe(issuerURL)
// Cache the result (even if nil to avoid repeated expensive operations)
r.cache[issuerURL] = detectedProvider
return detectedProvider
}
// detectProviderUnsafe performs the actual provider detection logic.
// This method assumes the caller holds the appropriate lock and should not be called directly.
func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
// Normalize issuer URL for consistent matching
normalizedURL, err := url.Parse(issuerURL)
if err != nil {
@@ -86,12 +108,10 @@ func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
switch p.GetType() {
case ProviderTypeGoogle:
if strings.Contains(host, "accounts.google.com") {
r.cache[issuerURL] = p
return p
}
case ProviderTypeAzure:
if strings.Contains(host, "login.microsoftonline.com") || strings.Contains(host, "sts.windows.net") {
r.cache[issuerURL] = p
return p
}
}
@@ -100,7 +120,6 @@ func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
// Fallback to the generic provider if no specific provider is detected
for _, p := range r.providers {
if p.GetType() == ProviderTypeGeneric {
r.cache[issuerURL] = p
return p
}
}
+82 -37
View File
@@ -36,27 +36,27 @@ func createDefaultHTTPClient() *http.Client {
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 10 * time.Second, // Connection timeout for faster failure detection
KeepAlive: 30 * time.Second, // Keep-alive interval for connection reuse
Timeout: 5 * time.Second, // Reduced connection timeout for faster failure detection
KeepAlive: 15 * time.Second, // Reduced keep-alive interval for connection reuse
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 3 * time.Second, // TLS handshake timeout
ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-continue responses
MaxIdleConns: 10, // Reduced from 20 to prevent memory buildup
MaxIdleConnsPerHost: 2, // Reduced from 5 to limit per-host connections
IdleConnTimeout: 30 * time.Second, // Reduced from 60 to close idle connections faster
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 10, // Reduced from 20 to limit concurrent connections
ResponseHeaderTimeout: 5 * time.Second, // Timeout for reading response headers
DisableCompression: false, // Enable compression for bandwidth efficiency
WriteBufferSize: 4096, // Write buffer size for connections
ReadBufferSize: 4096, // Read buffer size for connections
TLSHandshakeTimeout: 2 * time.Second, // Reduced TLS handshake timeout
ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-continue responses
MaxIdleConns: 2, // Reduced from 5 to minimize idle connections
MaxIdleConnsPerHost: 1, // Keep minimal idle connections per host
IdleConnTimeout: 5 * time.Second, // Reduced from 15s to close idle connections faster
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 2, // Reduced from 5 to limit concurrent connections
ResponseHeaderTimeout: 3 * time.Second, // Reduced timeout for reading response headers
DisableCompression: false, // Enable compression for bandwidth efficiency
WriteBufferSize: 4096, // Write buffer size for connections
ReadBufferSize: 4096, // Read buffer size for connections
}
return &http.Client{
Timeout: time.Second * 10, // HTTP client timeout
Timeout: time.Second * 5, // Reduced HTTP client timeout to prevent hanging connections
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// Limit redirects to prevent redirect loops
@@ -124,7 +124,7 @@ type CacheManager struct {
// It initializes all cache types on first call with appropriate default settings.
// This ensures thread-safe initialization and consistent cache behavior across
// the entire application lifecycle.
func GetGlobalCacheManager() *CacheManager {
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
cacheManagerOnce.Do(func() {
globalCacheManager = &CacheManager{
tokenBlacklist: func() *Cache {
@@ -133,7 +133,7 @@ func GetGlobalCacheManager() *CacheManager {
return c
}(),
tokenCache: NewTokenCache(),
metadataCache: NewMetadataCache(),
metadataCache: NewMetadataCache(wg),
jwkCache: &JWKCache{},
}
})
@@ -283,13 +283,17 @@ type TraefikOidc struct {
issuerURL string
revocationURL string
scopes []string
goroutineWG sync.WaitGroup
goroutineWG *sync.WaitGroup
refreshGracePeriod time.Duration
shutdownOnce sync.Once
forceHTTPS bool
enablePKCE bool
overrideScopes bool
suppressDiagnosticLogs bool
firstRequestReceived bool
firstRequestMutex sync.Mutex
metadataRefreshStarted bool
providerURL string
}
// ProviderMetadata represents the OpenID Connect provider's discovery metadata.
@@ -690,13 +694,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
} else {
httpClient = createDefaultHTTPClient()
}
cacheManager := GetGlobalCacheManager()
goroutineWG := &sync.WaitGroup{}
cacheManager := GetGlobalCacheManager(goroutineWG)
pluginCtx, cancelFunc := context.WithCancel(context.Background())
t := &TraefikOidc{
next: next,
name: name,
goroutineWG: goroutineWG,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -771,7 +777,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.tokenVerifier = t
t.jwtVerifier = t
t.startTokenCleanup()
// Delay startTokenCleanup() until first request
t.tokenExchanger = t // Initialize the interface field to self
// Initialize and parse header templates with safe field access
@@ -815,6 +821,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
startReplayCacheCleanup(pluginCtx, logger)
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
// Store provider URL for later use
t.providerURL = config.ProviderURL
go t.initializeMetadata(config.ProviderURL)
return t, nil
@@ -851,8 +860,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Successfully initialized provider metadata")
t.updateMetadataEndpoints(metadata)
// Start metadata refresh goroutine
go t.startMetadataRefresh(providerURL)
// Delay metadata refresh goroutine until first request
// It will be started in ServeHTTP when needed
// Only close channel on success
close(t.initComplete)
@@ -888,11 +897,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
// Use longer interval to reduce memory pressure from refresh attempts
ticker := time.NewTicker(2 * time.Hour) // Increased from 1 hour
t.goroutineWG.Add(1) // Track this goroutine
// Check if WaitGroup is initialized (it might be nil in tests)
if t.goroutineWG != nil {
t.goroutineWG.Add(1) // Track this goroutine
}
go func() {
defer t.goroutineWG.Done() // Signal completion when goroutine exits
defer ticker.Stop() // Ensure ticker is always stopped
defer func() {
if t.goroutineWG != nil {
t.goroutineWG.Done() // Signal completion when goroutine exits
}
}()
defer ticker.Stop() // Ensure ticker is always stopped
consecutiveFailures := 0
for {
@@ -1056,6 +1073,23 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
// - Authentication state management
// - Header injection for authenticated requests
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Start background tasks on first request (except for health checks)
if !t.firstRequestReceived && !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
t.firstRequestReceived = true
t.logger.Debug("Starting background tasks on first request")
t.startTokenCleanup()
// Also start metadata refresh if not already started
if !t.metadataRefreshStarted && t.providerURL != "" {
t.metadataRefreshStarted = true
go t.startMetadataRefresh(t.providerURL)
}
}
t.firstRequestMutex.Unlock()
}
// Wait for provider metadata initialization to complete
select {
case <-t.initComplete:
@@ -2048,11 +2082,17 @@ func (t *TraefikOidc) validateHost(host string) error {
// panic recovery to ensure stability.
func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
t.goroutineWG.Add(1) // Track this goroutine
// Check if WaitGroup is initialized (it might be nil in tests)
if t.goroutineWG != nil {
t.goroutineWG.Add(1) // Track this goroutine
}
go func() {
defer func() {
t.goroutineWG.Done() // Signal completion when goroutine exits
ticker.Stop() // Ensure ticker is always stopped
if t.goroutineWG != nil {
t.goroutineWG.Done() // Signal completion when goroutine exits
}
ticker.Stop() // Ensure ticker is always stopped
// CRITICAL: Recover from panics to prevent middleware crashes
if r := recover(); r != nil {
@@ -2838,17 +2878,22 @@ func (t *TraefikOidc) Close() error {
t.logger.Debug("metadataRefreshStopChan closed")
}
done := make(chan struct{})
go func() {
t.goroutineWG.Wait()
close(done)
}()
// Only wait for goroutines if WaitGroup is initialized
if t.goroutineWG != nil {
done := make(chan struct{})
go func() {
t.goroutineWG.Wait()
close(done)
}()
select {
case <-done:
t.logger.Debug("All background goroutines stopped gracefully")
case <-time.After(10 * time.Second):
t.logger.Errorf("Timeout waiting for background goroutines to stop")
select {
case <-done:
t.logger.Debug("All background goroutines stopped gracefully")
case <-time.After(10 * time.Second):
t.logger.Errorf("Timeout waiting for background goroutines to stop")
}
} else {
t.logger.Debug("No goroutineWG to wait for (likely in test)")
}
if t.httpClient != nil {
+111
View File
@@ -0,0 +1,111 @@
package traefikoidc
import (
"container/list"
"net/http"
"sync"
"time"
)
// LazyBackgroundTask wraps BackgroundTask to start only when needed
type LazyBackgroundTask struct {
*BackgroundTask
started bool
startOnce sync.Once
}
// NewLazyBackgroundTask creates a background task that doesn't start immediately
func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger, wg ...*sync.WaitGroup) *LazyBackgroundTask {
return &LazyBackgroundTask{
BackgroundTask: NewBackgroundTask(name, interval, taskFunc, logger, wg...),
started: false,
}
}
// StartIfNeeded starts the task only if it hasn't been started yet
func (lt *LazyBackgroundTask) StartIfNeeded() {
lt.startOnce.Do(func() {
if !lt.started {
lt.BackgroundTask.Start()
lt.started = true
}
})
}
// Stop stops the task if it was started
func (lt *LazyBackgroundTask) Stop() {
if lt.started {
lt.BackgroundTask.Stop()
lt.started = false
lt.startOnce = sync.Once{} // Reset for potential restart
}
}
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use
func NewLazyCacheWithLogger(logger *Logger) *Cache {
if logger == nil {
logger = newNoOpLogger()
}
c := &Cache{
items: make(map[string]CacheItem, DefaultMaxSize),
order: list.New(),
elems: make(map[string]*list.Element, DefaultMaxSize),
maxSize: DefaultMaxSize,
autoCleanupInterval: 10 * time.Minute, // Increased from 5 minutes
logger: logger,
}
// Don't start cleanup immediately - it will be started on first use
return c
}
// NewLazyCache creates a cache that doesn't start cleanup immediately
func NewLazyCache() *Cache {
return NewLazyCacheWithLogger(nil)
}
// CleanupIdleConnections periodically closes idle HTTP connections
func CleanupIdleConnections(client *http.Client, interval time.Duration, stopChan <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if transport, ok := client.Transport.(*http.Transport); ok {
transport.CloseIdleConnections()
}
case <-stopChan:
// Final cleanup
if transport, ok := client.Transport.(*http.Transport); ok {
transport.CloseIdleConnections()
}
return
}
}
}
// OptimizedMiddlewareConfig provides configuration for memory-optimized middleware
type OptimizedMiddlewareConfig struct {
// DelayBackgroundTasks delays starting background tasks until first request
DelayBackgroundTasks bool
// ReducedCleanupIntervals uses longer intervals for cleanup tasks
ReducedCleanupIntervals bool
// AggressiveConnectionCleanup closes idle connections more aggressively
AggressiveConnectionCleanup bool
// MinimalCacheSize uses smaller default cache sizes
MinimalCacheSize bool
}
// DefaultOptimizedConfig returns a configuration optimized for low memory usage
func DefaultOptimizedConfig() *OptimizedMiddlewareConfig {
return &OptimizedMiddlewareConfig{
DelayBackgroundTasks: true,
ReducedCleanupIntervals: true,
AggressiveConnectionCleanup: true,
MinimalCacheSize: true,
}
}
+2 -2
View File
@@ -48,7 +48,7 @@ func TestMemoryLeakFixes(t *testing.T) {
t.Run("Global cache manager cleanup", func(t *testing.T) {
// Get the global cache manager
cm := GetGlobalCacheManager()
cm := GetGlobalCacheManager(nil)
if cm == nil {
t.Fatal("Failed to get global cache manager")
}
@@ -64,7 +64,7 @@ func TestMemoryLeakFixes(t *testing.T) {
}
// Verify it can be re-initialized
cm2 := GetGlobalCacheManager()
cm2 := GetGlobalCacheManager(nil)
if cm2 == nil {
t.Fatal("Failed to re-initialize global cache manager")
}
+9 -5
View File
@@ -19,23 +19,27 @@ type MetadataCache struct {
logger *Logger
autoCleanupInterval time.Duration
mutex sync.RWMutex
wg *sync.WaitGroup
stopChan chan struct{}
}
// NewMetadataCache creates a new MetadataCache instance.
// It initializes the cache structure and starts the background cleanup task.
func NewMetadataCache() *MetadataCache {
return NewMetadataCacheWithLogger(nil)
func NewMetadataCache(wg *sync.WaitGroup) *MetadataCache {
return NewMetadataCacheWithLogger(wg, nil)
}
// NewMetadataCacheWithLogger creates a new MetadataCache with a specified logger.
func NewMetadataCacheWithLogger(logger *Logger) *MetadataCache {
func NewMetadataCacheWithLogger(wg *sync.WaitGroup, logger *Logger) *MetadataCache {
if logger == nil {
logger = newNoOpLogger()
}
c := &MetadataCache{
autoCleanupInterval: 5 * time.Minute,
autoCleanupInterval: 30 * time.Minute, // Increased from 5 minutes since metadata changes rarely
logger: logger,
wg: wg,
stopChan: make(chan struct{}),
}
c.startAutoCleanup()
return c
@@ -200,7 +204,7 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
// startAutoCleanup starts the background task that periodically calls Cleanup
// to remove expired metadata from the cache.
func (c *MetadataCache) startAutoCleanup() {
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger, c.wg)
c.cleanupTask.Start()
}
+852
View File
@@ -0,0 +1,852 @@
package traefikoidc
import (
"bytes"
"fmt"
"net/http"
"runtime"
"runtime/pprof"
"sync"
"time"
)
// MemoryProfiler defines the interface for memory profiling operations
type MemoryProfiler interface {
// TakeSnapshot captures current memory statistics
TakeSnapshot() (*MemorySnapshot, error)
// StartProfiling begins memory profiling with specified configuration
StartProfiling(config ProfilingConfig) error
// StopProfiling ends memory profiling and returns final snapshot
StopProfiling() (*MemorySnapshot, error)
// GetCurrentStats returns current runtime memory statistics
GetCurrentStats() *runtime.MemStats
// AnalyzeLeaks performs leak detection analysis
AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis
}
// MemorySnapshot represents a point-in-time capture of memory statistics
type MemorySnapshot struct {
Timestamp time.Time
RuntimeStats runtime.MemStats
HeapProfile []byte
GoroutineProfile []byte
CustomMetrics map[string]interface{}
}
// LeakAnalysis contains the results of memory leak detection
type LeakAnalysis struct {
HasLeak bool
LeakDescription string
MemoryIncrease uint64
GoroutineIncrease int
SuspectedLeaks []string
Recommendations []string
}
// ProfilingManager coordinates memory profiling operations
type ProfilingManager struct {
mu sync.RWMutex
isProfiling bool
startTime time.Time
baselineSnapshot *MemorySnapshot
config ProfilingConfig
logger *Logger
profilers map[string]MemoryProfiler
}
// ProfilingConfig contains configuration for profiling operations
type ProfilingConfig struct {
EnableHeapProfiling bool
EnableGoroutineProfiling bool
SnapshotInterval time.Duration
LeakThresholdMB uint64
MaxSnapshots int
EnableContinuousMonitoring bool
MonitoringInterval time.Duration
}
// LeakDetectionConfig contains configuration for leak detection
type LeakDetectionConfig struct {
EnableLeakDetection bool
LeakThresholdMB uint64
GoroutineLeakThreshold int
SessionPoolThreshold int
CacheMemoryThreshold uint64
HTTPClientThreshold int
TokenCompressionThreshold uint64
}
// NewProfilingManager creates a new profiling manager instance
func NewProfilingManager(logger *Logger) *ProfilingManager {
if logger == nil {
logger = newNoOpLogger()
}
return &ProfilingManager{
profilers: make(map[string]MemoryProfiler),
config: ProfilingConfig{
EnableHeapProfiling: true,
EnableGoroutineProfiling: true,
SnapshotInterval: 30 * time.Second,
LeakThresholdMB: 50, // 50MB
MaxSnapshots: 100,
EnableContinuousMonitoring: true,
MonitoringInterval: 60 * time.Second,
},
logger: logger,
}
}
// TakeSnapshot captures current memory statistics
func (pm *ProfilingManager) TakeSnapshot() (*MemorySnapshot, error) {
var buf bytes.Buffer
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
// Capture runtime memory statistics
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Capture heap profile if enabled
if pm.config.EnableHeapProfiling {
if err := pprof.WriteHeapProfile(&buf); err != nil {
pm.logger.Errorf("Failed to capture heap profile: %v", err)
} else {
snapshot.HeapProfile = make([]byte, buf.Len())
copy(snapshot.HeapProfile, buf.Bytes())
buf.Reset()
}
}
// Capture goroutine profile if enabled
if pm.config.EnableGoroutineProfiling {
if err := pprof.Lookup("goroutine").WriteTo(&buf, 0); err != nil {
pm.logger.Errorf("Failed to capture goroutine profile: %v", err)
} else {
snapshot.GoroutineProfile = make([]byte, buf.Len())
copy(snapshot.GoroutineProfile, buf.Bytes())
buf.Reset()
}
}
// Capture custom metrics from registered profilers
pm.mu.RLock()
for name, profiler := range pm.profilers {
if customStats := profiler.GetCurrentStats(); customStats != nil {
snapshot.CustomMetrics[name] = customStats
}
}
pm.mu.RUnlock()
return snapshot, nil
}
// StartProfiling begins memory profiling with specified configuration
func (pm *ProfilingManager) StartProfiling(config ProfilingConfig) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if pm.isProfiling {
return fmt.Errorf("profiling already in progress")
}
pm.config = config
pm.isProfiling = true
pm.startTime = time.Now()
// Take baseline snapshot
baseline, err := pm.TakeSnapshot()
if err != nil {
pm.isProfiling = false
return fmt.Errorf("failed to take baseline snapshot: %w", err)
}
pm.baselineSnapshot = baseline
pm.logger.Infof("Memory profiling started at %v", pm.startTime)
return nil
}
// StopProfiling ends memory profiling and returns final snapshot
func (pm *ProfilingManager) StopProfiling() (*MemorySnapshot, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
if !pm.isProfiling {
return nil, fmt.Errorf("profiling not in progress")
}
// Take final snapshot
finalSnapshot, err := pm.TakeSnapshot()
if err != nil {
pm.logger.Errorf("Failed to take final snapshot: %v", err)
}
pm.isProfiling = false
duration := time.Since(pm.startTime)
pm.logger.Infof("Memory profiling stopped after %v", duration)
return finalSnapshot, err
}
// GetCurrentStats returns current runtime memory statistics
func (pm *ProfilingManager) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks performs leak detection analysis
func (pm *ProfilingManager) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.HasLeak = false
analysis.LeakDescription = "Insufficient data for leak analysis"
return analysis
}
// Calculate memory increase
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
analysis.MemoryIncrease = memoryIncrease
// Calculate goroutine increase
currentGoroutines := runtime.NumGoroutine()
baselineGoroutines := runtime.NumGoroutine() // Note: This is not accurate for baseline, but we don't have historical data
goroutineIncrease := currentGoroutines - baselineGoroutines
analysis.GoroutineIncrease = goroutineIncrease
// Check for memory leaks
memoryThreshold := pm.config.LeakThresholdMB * 1024 * 1024 // Convert MB to bytes
if memoryIncrease > memoryThreshold {
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
fmt.Sprintf("Memory usage increased by %.2f MB", float64(memoryIncrease)/(1024*1024)))
analysis.Recommendations = append(analysis.Recommendations,
"Consider checking for unreleased memory pools or growing caches")
}
// Check for goroutine leaks
if goroutineIncrease > 10 { // Arbitrary threshold
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
fmt.Sprintf("Goroutine count increased by %d", goroutineIncrease))
analysis.Recommendations = append(analysis.Recommendations,
"Check for goroutines that are not being properly cleaned up")
}
if analysis.HasLeak {
analysis.LeakDescription = fmt.Sprintf("Potential memory leak detected: %s",
fmt.Sprintf("%.2f MB increase, %d goroutines", float64(memoryIncrease)/(1024*1024), goroutineIncrease))
} else {
analysis.LeakDescription = "No significant memory leaks detected"
}
return analysis
}
// RegisterProfiler registers a component-specific profiler
func (pm *ProfilingManager) RegisterProfiler(name string, profiler MemoryProfiler) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.profilers[name] = profiler
pm.logger.Debugf("Registered profiler: %s", name)
}
// UnregisterProfiler removes a component-specific profiler
func (pm *ProfilingManager) UnregisterProfiler(name string) {
pm.mu.Lock()
defer pm.mu.Unlock()
delete(pm.profilers, name)
pm.logger.Debugf("Unregistered profiler: %s", name)
}
// GetRegisteredProfilers returns list of registered profiler names
func (pm *ProfilingManager) GetRegisteredProfilers() []string {
pm.mu.RLock()
defer pm.mu.RUnlock()
names := make([]string, 0, len(pm.profilers))
for name := range pm.profilers {
names = append(names, name)
}
return names
}
// MemoryTestOrchestrator coordinates memory leak testing across components
type MemoryTestOrchestrator struct {
mu sync.RWMutex
profilers map[string]MemoryProfiler
config LeakDetectionConfig
logger *Logger
isRunning bool
stopChan chan struct{}
testResults map[string]*LeakAnalysis
}
// NewMemoryTestOrchestrator creates a new test orchestrator
func NewMemoryTestOrchestrator(config LeakDetectionConfig, logger *Logger) *MemoryTestOrchestrator {
if logger == nil {
logger = newNoOpLogger()
}
return &MemoryTestOrchestrator{
profilers: make(map[string]MemoryProfiler),
config: config,
logger: logger,
stopChan: make(chan struct{}),
testResults: make(map[string]*LeakAnalysis),
}
}
// RegisterComponent registers a component for memory leak testing
func (mto *MemoryTestOrchestrator) RegisterComponent(name string, profiler MemoryProfiler) {
mto.mu.Lock()
defer mto.mu.Unlock()
mto.profilers[name] = profiler
mto.logger.Debugf("Registered component for leak testing: %s", name)
}
// UnregisterComponent removes a component from leak testing
func (mto *MemoryTestOrchestrator) UnregisterComponent(name string) {
mto.mu.Lock()
defer mto.mu.Unlock()
delete(mto.profilers, name)
delete(mto.testResults, name)
mto.logger.Debugf("Unregistered component from leak testing: %s", name)
}
// StartLeakDetection begins continuous leak detection monitoring
func (mto *MemoryTestOrchestrator) StartLeakDetection() error {
mto.mu.Lock()
defer mto.mu.Unlock()
if mto.isRunning {
return fmt.Errorf("leak detection already running")
}
if !mto.config.EnableLeakDetection {
return fmt.Errorf("leak detection is disabled in configuration")
}
mto.isRunning = true
go mto.runLeakDetection()
mto.logger.Infof("Memory leak detection started")
return nil
}
// StopLeakDetection stops continuous leak detection monitoring
func (mto *MemoryTestOrchestrator) StopLeakDetection() error {
mto.mu.Lock()
defer mto.mu.Unlock()
if !mto.isRunning {
return fmt.Errorf("leak detection not running")
}
mto.isRunning = false
close(mto.stopChan)
mto.stopChan = make(chan struct{}) // Reset for potential restart
mto.logger.Infof("Memory leak detection stopped")
return nil
}
// runLeakDetection performs continuous leak detection monitoring
func (mto *MemoryTestOrchestrator) runLeakDetection() {
ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
defer ticker.Stop()
baselineSnapshots := make(map[string]*MemorySnapshot)
// Take initial baseline snapshots
mto.mu.RLock()
for name, profiler := range mto.profilers {
if snapshot, err := profiler.TakeSnapshot(); err == nil {
baselineSnapshots[name] = snapshot
}
}
mto.mu.RUnlock()
for {
select {
case <-ticker.C:
mto.performLeakCheck(baselineSnapshots)
case <-mto.stopChan:
return
}
}
}
// performLeakCheck performs leak detection for all registered components
func (mto *MemoryTestOrchestrator) performLeakCheck(baselineSnapshots map[string]*MemorySnapshot) {
mto.mu.RLock()
defer mto.mu.RUnlock()
for name, profiler := range mto.profilers {
baseline, exists := baselineSnapshots[name]
if !exists {
continue
}
current, err := profiler.TakeSnapshot()
if err != nil {
mto.logger.Errorf("Failed to take snapshot for component %s: %v", name, err)
continue
}
analysis := profiler.AnalyzeLeaks(baseline, current)
if analysis.HasLeak {
mto.logger.Errorf("Memory leak detected in component %s: %s", name, analysis.LeakDescription)
for _, rec := range analysis.Recommendations {
mto.logger.Errorf("Recommendation for %s: %s", name, rec)
}
}
mto.testResults[name] = analysis
}
}
// GetLeakAnalysis returns leak analysis for a specific component
func (mto *MemoryTestOrchestrator) GetLeakAnalysis(componentName string) (*LeakAnalysis, bool) {
mto.mu.RLock()
defer mto.mu.RUnlock()
analysis, exists := mto.testResults[componentName]
return analysis, exists
}
// GetAllLeakAnalyses returns leak analyses for all components
func (mto *MemoryTestOrchestrator) GetAllLeakAnalyses() map[string]*LeakAnalysis {
mto.mu.RLock()
defer mto.mu.RUnlock()
results := make(map[string]*LeakAnalysis)
for name, analysis := range mto.testResults {
results[name] = analysis
}
return results
}
// Component-specific profiler implementations
// SessionPoolProfiler monitors session pool memory usage
type SessionPoolProfiler struct {
sessionManager *SessionManager
logger *Logger
}
// NewSessionPoolProfiler creates a new session pool profiler
func NewSessionPoolProfiler(sm *SessionManager, logger *Logger) *SessionPoolProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &SessionPoolProfiler{
sessionManager: sm,
logger: logger,
}
}
// TakeSnapshot captures session pool memory statistics
func (spp *SessionPoolProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
// Capture runtime stats
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add session pool specific metrics
snapshot.CustomMetrics["session_pool_metrics"] = spp.sessionManager.GetSessionMetrics()
return snapshot, nil
}
// StartProfiling begins profiling (no-op for session pools)
func (spp *SessionPoolProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling (no-op for session pools)
func (spp *SessionPoolProfiler) StopProfiling() (*MemorySnapshot, error) {
return spp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (spp *SessionPoolProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes session pool for leaks
func (spp *SessionPoolProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient session pool data"
return analysis
}
// Check for session pool specific leaks
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 10*1024*1024 { // 10MB threshold
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"Session pool memory usage increased significantly")
analysis.Recommendations = append(analysis.Recommendations,
"Check for sessions not being returned to pool properly")
}
return analysis
}
// CacheMemoryProfiler monitors cache memory usage
type CacheMemoryProfiler struct {
cache *Cache
logger *Logger
}
// NewCacheMemoryProfiler creates a new cache memory profiler
func NewCacheMemoryProfiler(cache *Cache, logger *Logger) *CacheMemoryProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &CacheMemoryProfiler{
cache: cache,
logger: logger,
}
}
// TakeSnapshot captures cache memory statistics
func (cmp *CacheMemoryProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add cache-specific metrics (would need to be added to Cache struct)
snapshot.CustomMetrics["cache_size"] = "unknown" // Placeholder
return snapshot, nil
}
// StartProfiling begins profiling (no-op for cache)
func (cmp *CacheMemoryProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling
func (cmp *CacheMemoryProfiler) StopProfiling() (*MemorySnapshot, error) {
return cmp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (cmp *CacheMemoryProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes cache for memory leaks
func (cmp *CacheMemoryProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient cache data"
return analysis
}
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 20*1024*1024 { // 20MB threshold for cache
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"Cache memory usage increased significantly")
analysis.Recommendations = append(analysis.Recommendations,
"Check cache size limits and cleanup intervals")
}
return analysis
}
// HTTPClientProfiler monitors HTTP client connection pools
type HTTPClientProfiler struct {
httpClient *http.Client
logger *Logger
}
// NewHTTPClientProfiler creates a new HTTP client profiler
func NewHTTPClientProfiler(client *http.Client, logger *Logger) *HTTPClientProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &HTTPClientProfiler{
httpClient: client,
logger: logger,
}
}
// TakeSnapshot captures HTTP client memory statistics
func (hcp *HTTPClientProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add HTTP client specific metrics
if transport, ok := hcp.httpClient.Transport.(*http.Transport); ok {
snapshot.CustomMetrics["idle_connections"] = transport.IdleConnTimeout.String()
snapshot.CustomMetrics["max_idle_conns"] = transport.MaxIdleConns
}
return snapshot, nil
}
// StartProfiling begins profiling (no-op for HTTP client)
func (hcp *HTTPClientProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling
func (hcp *HTTPClientProfiler) StopProfiling() (*MemorySnapshot, error) {
return hcp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (hcp *HTTPClientProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes HTTP client for connection leaks
func (hcp *HTTPClientProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient HTTP client data"
return analysis
}
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 5*1024*1024 { // 5MB threshold for HTTP client
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"HTTP client memory usage increased significantly")
analysis.Recommendations = append(analysis.Recommendations,
"Check for HTTP response bodies not being drained properly")
}
return analysis
}
// TokenCompressionProfiler monitors token compression memory usage
type TokenCompressionProfiler struct {
compressionPool *TokenCompressionPool
logger *Logger
}
// NewTokenCompressionProfiler creates a new token compression profiler
func NewTokenCompressionProfiler(pool *TokenCompressionPool, logger *Logger) *TokenCompressionProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &TokenCompressionProfiler{
compressionPool: pool,
logger: logger,
}
}
// TakeSnapshot captures token compression memory statistics
func (tcp *TokenCompressionProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add compression pool specific metrics
snapshot.CustomMetrics["compression_pool_active"] = true
return snapshot, nil
}
// StartProfiling begins profiling (no-op for compression)
func (tcp *TokenCompressionProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling
func (tcp *TokenCompressionProfiler) StopProfiling() (*MemorySnapshot, error) {
return tcp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (tcp *TokenCompressionProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes token compression for memory leaks
func (tcp *TokenCompressionProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient compression data"
return analysis
}
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 2*1024*1024 { // 2MB threshold for compression
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"Token compression memory usage increased significantly")
analysis.Recommendations = append(analysis.Recommendations,
"Check for compression buffers not being returned to pool")
}
return analysis
}
// MemoryPoolProfiler monitors memory pool usage and detects leaks
type MemoryPoolProfiler struct {
memoryPoolManager *MemoryPoolManager
tokenCompressionPool *TokenCompressionPool
logger *Logger
}
// NewMemoryPoolProfiler creates a new memory pool profiler
func NewMemoryPoolProfiler(memoryPoolManager *MemoryPoolManager, tokenCompressionPool *TokenCompressionPool, logger *Logger) *MemoryPoolProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &MemoryPoolProfiler{
memoryPoolManager: memoryPoolManager,
tokenCompressionPool: tokenCompressionPool,
logger: logger,
}
}
// TakeSnapshot captures memory pool statistics
func (mpp *MemoryPoolProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
// Capture runtime stats
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add memory pool metrics
if mpp.memoryPoolManager != nil {
snapshot.CustomMetrics["memory_pool_active"] = true
// Note: sync.Pool doesn't expose internal statistics, so we track usage patterns
}
if mpp.tokenCompressionPool != nil {
snapshot.CustomMetrics["token_compression_pool_active"] = true
}
return snapshot, nil
}
// StartProfiling begins profiling (no-op for memory pools)
func (mpp *MemoryPoolProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling
func (mpp *MemoryPoolProfiler) StopProfiling() (*MemorySnapshot, error) {
return mpp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (mpp *MemoryPoolProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes memory pools for leaks
func (mpp *MemoryPoolProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient memory pool data"
return analysis
}
// Check for memory leaks in pool operations
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 5*1024*1024 { // 5MB threshold for pool operations
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"Memory pool operations caused significant memory increase")
analysis.Recommendations = append(analysis.Recommendations,
"Check for objects not being returned to memory pools properly")
}
return analysis
}
// Global profiling manager instance
var globalProfilingManager *ProfilingManager
var profilingManagerOnce sync.Once
// GetGlobalProfilingManager returns the singleton profiling manager
func GetGlobalProfilingManager() *ProfilingManager {
profilingManagerOnce.Do(func() {
globalProfilingManager = NewProfilingManager(nil)
})
return globalProfilingManager
}
// Global test orchestrator instance
var globalTestOrchestrator *MemoryTestOrchestrator
var testOrchestratorOnce sync.Once
// GetGlobalTestOrchestrator returns the singleton test orchestrator
func GetGlobalTestOrchestrator() *MemoryTestOrchestrator {
testOrchestratorOnce.Do(func() {
config := LeakDetectionConfig{
EnableLeakDetection: true,
LeakThresholdMB: 50,
GoroutineLeakThreshold: 10,
SessionPoolThreshold: 100,
CacheMemoryThreshold: 20 * 1024 * 1024, // 20MB
HTTPClientThreshold: 50,
TokenCompressionThreshold: 2 * 1024 * 1024, // 2MB
}
globalTestOrchestrator = NewMemoryTestOrchestrator(config, nil)
})
return globalTestOrchestrator
}
+788
View File
@@ -0,0 +1,788 @@
package traefikoidc
import (
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"runtime"
"sync"
"testing"
"time"
)
func TestProfilingManager(t *testing.T) {
logger := NewLogger("debug")
pm := NewProfilingManager(logger)
// Test taking a snapshot
snapshot, err := pm.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("Snapshot is nil")
}
if snapshot.RuntimeStats.Alloc == 0 {
t.Error("Runtime stats Alloc should not be zero")
}
if snapshot.Timestamp.IsZero() {
t.Error("Snapshot timestamp should not be zero")
}
}
func TestMemoryTestOrchestrator(t *testing.T) {
logger := NewLogger("debug")
config := LeakDetectionConfig{
EnableLeakDetection: true,
LeakThresholdMB: 10,
}
mto := NewMemoryTestOrchestrator(config, logger)
// Test registering a component
sessionManager, err := NewSessionManager("test-key-32-chars-long-for-testing", false, "", logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
profiler := NewSessionPoolProfiler(sessionManager, logger)
mto.RegisterComponent("session_pool", profiler)
// Test getting leak analysis (should return false initially since no checks have been performed)
_, exists := mto.GetLeakAnalysis("session_pool")
if exists {
t.Error("Should not have leak analysis before any checks are performed")
}
// Perform a manual leak check
baseline, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take baseline snapshot: %v", err)
}
time.Sleep(10 * time.Millisecond) // Small delay
// Manually trigger leak check with baseline
baselineSnapshots := make(map[string]*MemorySnapshot)
baselineSnapshots["session_pool"] = baseline
mto.performLeakCheck(baselineSnapshots)
// Now test getting leak analysis
analysis, exists := mto.GetLeakAnalysis("session_pool")
if !exists {
t.Error("Should have leak analysis after performing checks")
}
if analysis == nil {
t.Error("Leak analysis should not be nil after checks")
}
}
func TestComponentProfilers(t *testing.T) {
logger := NewLogger("debug")
// Test Session Pool Profiler
sessionManager, err := NewSessionManager("test-key-32-chars-long-for-testing", false, "", logger)
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
spp := NewSessionPoolProfiler(sessionManager, logger)
snapshot, err := spp.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take session pool snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("Session pool snapshot is nil")
}
// Test Cache Memory Profiler
cache := NewCache()
cmp := NewCacheMemoryProfiler(cache, logger)
snapshot, err = cmp.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take cache snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("Cache snapshot is nil")
}
// Test HTTP Client Profiler
httpClient := createDefaultHTTPClient()
hcp := NewHTTPClientProfiler(httpClient, logger)
snapshot, err = hcp.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take HTTP client snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("HTTP client snapshot is nil")
}
// Test Token Compression Profiler
compressionPool := NewTokenCompressionPool()
tcp := NewTokenCompressionProfiler(compressionPool, logger)
snapshot, err = tcp.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take compression snapshot: %v", err)
}
if snapshot == nil {
t.Fatal("Compression snapshot is nil")
}
}
func TestLeakAnalysis(t *testing.T) {
logger := NewLogger("debug")
pm := NewProfilingManager(logger)
// Create baseline snapshot
baseline, err := pm.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to create baseline: %v", err)
}
// Wait a bit and create current snapshot
time.Sleep(10 * time.Millisecond)
current, err := pm.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to create current snapshot: %v", err)
}
// Test leak analysis
analysis := pm.AnalyzeLeaks(baseline, current)
if analysis == nil {
t.Fatal("Leak analysis is nil")
}
// Analysis should not have leaks for normal operation
if analysis.HasLeak {
t.Logf("Leak detected: %s", analysis.LeakDescription)
// This is acceptable as the test environment may have varying memory usage
}
}
func TestGlobalInstances(t *testing.T) {
// Test global profiling manager
gpm := GetGlobalProfilingManager()
if gpm == nil {
t.Fatal("Global profiling manager is nil")
}
// Test global test orchestrator
gto := GetGlobalTestOrchestrator()
if gto == nil {
t.Fatal("Global test orchestrator is nil")
}
// Test that they're singletons
gpm2 := GetGlobalProfilingManager()
if gpm != gpm2 {
t.Error("Global profiling manager should be singleton")
}
gto2 := GetGlobalTestOrchestrator()
if gto != gto2 {
t.Error("Global test orchestrator should be singleton")
}
}
func TestProfilingConfig(t *testing.T) {
config := ProfilingConfig{
EnableHeapProfiling: true,
EnableGoroutineProfiling: true,
SnapshotInterval: 30 * time.Second,
LeakThresholdMB: 50,
MaxSnapshots: 100,
EnableContinuousMonitoring: true,
MonitoringInterval: 60 * time.Second,
}
if !config.EnableHeapProfiling {
t.Error("Heap profiling should be enabled")
}
if !config.EnableGoroutineProfiling {
t.Error("Goroutine profiling should be enabled")
}
if config.LeakThresholdMB != 50 {
t.Errorf("Expected leak threshold 50, got %d", config.LeakThresholdMB)
}
}
func TestLeakDetectionConfig(t *testing.T) {
config := LeakDetectionConfig{
EnableLeakDetection: true,
LeakThresholdMB: 50,
GoroutineLeakThreshold: 10,
SessionPoolThreshold: 100,
CacheMemoryThreshold: 20 * 1024 * 1024,
HTTPClientThreshold: 50,
TokenCompressionThreshold: 2 * 1024 * 1024,
}
if !config.EnableLeakDetection {
t.Error("Leak detection should be enabled")
}
if config.LeakThresholdMB != 50 {
t.Errorf("Expected leak threshold 50, got %d", config.LeakThresholdMB)
}
if config.CacheMemoryThreshold != 20*1024*1024 {
t.Errorf("Expected cache threshold 20MB, got %d", config.CacheMemoryThreshold)
}
}
// ProviderMetadataProfiler monitors provider metadata fetching and caching operations
type ProviderMetadataProfiler struct {
metadataCache *MetadataCache
httpClient *http.Client
logger *Logger
providerURL string
}
// NewProviderMetadataProfiler creates a new provider metadata profiler
func NewProviderMetadataProfiler(metadataCache *MetadataCache, httpClient *http.Client, providerURL string, logger *Logger) *ProviderMetadataProfiler {
if logger == nil {
logger = newNoOpLogger()
}
return &ProviderMetadataProfiler{
metadataCache: metadataCache,
httpClient: httpClient,
providerURL: providerURL,
logger: logger,
}
}
// TakeSnapshot captures current memory statistics for metadata operations
func (pmp *ProviderMetadataProfiler) TakeSnapshot() (*MemorySnapshot, error) {
snapshot := &MemorySnapshot{
Timestamp: time.Now(),
CustomMetrics: make(map[string]interface{}),
}
// Capture runtime memory statistics
runtime.ReadMemStats(&snapshot.RuntimeStats)
// Add metadata-specific metrics
snapshot.CustomMetrics["metadata_cache_size"] = 1 // Placeholder for cache size
snapshot.CustomMetrics["metadata_fetch_count"] = 0 // Placeholder for fetch count
snapshot.CustomMetrics["background_goroutines"] = runtime.NumGoroutine()
return snapshot, nil
}
// StartProfiling begins profiling (no-op for metadata profiler)
func (pmp *ProviderMetadataProfiler) StartProfiling(config ProfilingConfig) error {
return nil
}
// StopProfiling ends profiling
func (pmp *ProviderMetadataProfiler) StopProfiling() (*MemorySnapshot, error) {
return pmp.TakeSnapshot()
}
// GetCurrentStats returns current memory statistics
func (pmp *ProviderMetadataProfiler) GetCurrentStats() *runtime.MemStats {
stats := &runtime.MemStats{}
runtime.ReadMemStats(stats)
return stats
}
// AnalyzeLeaks analyzes metadata operations for memory leaks
func (pmp *ProviderMetadataProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
analysis := &LeakAnalysis{
SuspectedLeaks: make([]string, 0),
Recommendations: make([]string, 0),
}
if baseline == nil || current == nil {
analysis.LeakDescription = "Insufficient metadata data"
return analysis
}
// Check for memory leaks
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if memoryIncrease > 5*1024*1024 { // 5MB threshold for metadata operations
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
"Metadata operations memory usage increased significantly")
analysis.Recommendations = append(analysis.Recommendations,
"Check for metadata cache not being cleaned up properly")
}
// Check for goroutine leaks
goroutineIncrease := current.CustomMetrics["background_goroutines"].(int) - baseline.CustomMetrics["background_goroutines"].(int)
if goroutineIncrease > 2 { // Allow some variance
analysis.HasLeak = true
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
fmt.Sprintf("Goroutine count increased by %d during metadata operations", goroutineIncrease))
analysis.Recommendations = append(analysis.Recommendations,
"Check for background goroutines not being cleaned up")
}
return analysis
}
// TestProviderMetadataMemoryLeakDetection tests for memory leaks in provider metadata operations
func TestProviderMetadataMemoryLeakDetection(t *testing.T) {
logger := NewLogger("debug")
strictMode := os.Getenv("STRICT_MEMORY_TEST") == "true"
if strictMode {
t.Log("Running in strict memory test mode - will fail on detected leaks")
} else {
t.Log("Running in lenient memory test mode - will log warnings instead of failing")
}
config := LeakDetectionConfig{
EnableLeakDetection: true,
LeakThresholdMB: 10,
}
mto := NewMemoryTestOrchestrator(config, logger)
// Create mock HTTP server for metadata endpoint with failure simulation
requestCount := 0
serverFailures := 0
mockServer := &http.Server{
Addr: "localhost:0", // Let system assign port
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount++
if r.URL.Path == "/.well-known/openid-configuration" {
// Simulate occasional failures to test cache extension
if requestCount%4 == 0 { // Fail every 4th request
serverFailures++
w.WriteHeader(http.StatusInternalServerError)
return
}
metadata := ProviderMetadata{
Issuer: "https://mock-provider.com",
AuthURL: "https://mock-provider.com/auth",
TokenURL: "https://mock-provider.com/token",
JWKSURL: "https://mock-provider.com/jwks",
RevokeURL: "https://mock-provider.com/revoke",
EndSessionURL: "https://mock-provider.com/logout",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "max-age=3600") // 1 hour cache hint
json.NewEncoder(w).Encode(metadata)
} else {
http.NotFound(w, r)
}
}),
}
// Start mock server
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
go mockServer.Serve(listener)
defer mockServer.Close()
providerURL := fmt.Sprintf("http://%s", listener.Addr().String())
httpClient := createDefaultHTTPClient()
// Create metadata cache with WaitGroup for proper goroutine synchronization
var wg sync.WaitGroup
metadataCache := NewMetadataCacheWithLogger(&wg, logger)
// Create profiler
profiler := NewProviderMetadataProfiler(metadataCache, httpClient, providerURL, logger)
mto.RegisterComponent("provider_metadata", profiler)
// Take initial baseline
baseline, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take baseline snapshot: %v", err)
}
initialGoroutines := runtime.NumGoroutine()
// Phase 1: Simulate periodic metadata fetching with some failures
t.Log("Phase 1: Testing periodic fetching with occasional failures...")
for i := 0; i < 20; i++ {
_, err := metadataCache.GetMetadata(providerURL, httpClient, logger)
if err != nil {
t.Logf("Metadata fetch %d failed (expected for cache extension testing): %v", i+1, err)
} else {
t.Logf("Metadata fetch %d succeeded", i+1)
}
time.Sleep(100 * time.Millisecond)
}
// Wait for background cleanup (normally every 5 minutes)
time.Sleep(300 * time.Millisecond)
// Take intermediate snapshot
intermediate, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take intermediate snapshot: %v", err)
}
// Phase 2: Continue with more fetches to test sustained operation
t.Log("Phase 2: Testing sustained operation with 1000 iterations...")
for i := 20; i < 1020; i++ {
_, err := metadataCache.GetMetadata(providerURL, httpClient, logger)
if err != nil {
t.Logf("Metadata fetch %d failed: %v", i+1, err)
}
time.Sleep(50 * time.Millisecond) // Reduced sleep for faster execution
}
// Take final snapshot
current, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take current snapshot: %v", err)
}
finalGoroutines := runtime.NumGoroutine()
// Analyze for leaks
analysis := profiler.AnalyzeLeaks(baseline, current)
// Assertions for memory leaks
if analysis.HasLeak {
if strictMode {
t.Errorf("Memory leak detected in provider metadata operations: %s", analysis.LeakDescription)
for _, leak := range analysis.SuspectedLeaks {
t.Errorf("Suspected leak: %s", leak)
}
} else {
t.Logf("Memory leak warning in provider metadata operations: %s", analysis.LeakDescription)
for _, leak := range analysis.SuspectedLeaks {
t.Logf("Suspected leak: %s", leak)
}
}
for _, rec := range analysis.Recommendations {
t.Logf("Recommendation: %s", rec)
}
}
// Check total memory growth
totalMemoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if totalMemoryIncrease > 20*1024*1024 { // 20MB threshold for entire test
if strictMode {
t.Errorf("Total memory usage increased by %.2f MB during metadata operations", float64(totalMemoryIncrease)/(1024*1024))
} else {
t.Logf("Total memory usage increased by %.2f MB during metadata operations", float64(totalMemoryIncrease)/(1024*1024))
}
}
// Check for gradual memory growth patterns
intermediateMemoryIncrease := intermediate.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if intermediateMemoryIncrease > 10*1024*1024 { // 10MB threshold for first phase
if strictMode {
t.Errorf("Memory usage increased by %.2f MB during first phase of metadata operations", float64(intermediateMemoryIncrease)/(1024*1024))
} else {
t.Logf("Memory usage increased by %.2f MB during first phase of metadata operations", float64(intermediateMemoryIncrease)/(1024*1024))
}
}
// Check goroutine count stability
goroutineIncrease := finalGoroutines - initialGoroutines
if goroutineIncrease > 5 { // Allow some variance for test environment
if strictMode {
t.Errorf("Goroutine count increased by %d during metadata operations (initial: %d, final: %d)",
goroutineIncrease, initialGoroutines, finalGoroutines)
} else {
t.Logf("Goroutine count increased by %d during metadata operations (initial: %d, final: %d)",
goroutineIncrease, initialGoroutines, finalGoroutines)
}
}
// Phase 3: Test cache extension behavior on persistent failures
t.Log("Phase 3: Testing cache extension on persistent failures...")
// Stop mock server to simulate provider unavailability
mockServer.Close()
// Try multiple fetches after server shutdown
postShutdownFailures := 0
for i := 0; i < 5; i++ {
_, err = metadataCache.GetMetadata(providerURL, httpClient, logger)
if err != nil {
postShutdownFailures++
t.Logf("Expected failure %d after server shutdown: %v", i+1, err)
} else {
t.Logf("Unexpected success %d after server shutdown - cache extension working", i+1)
}
time.Sleep(200 * time.Millisecond)
}
if postShutdownFailures == 0 {
if strictMode {
t.Error("Expected some metadata fetches to fail after server shutdown")
} else {
t.Log("Warning: No metadata fetches failed after server shutdown - cache extension may not be working as expected")
}
}
// Phase 4: Test background goroutine lifecycle and cleanup
t.Log("Phase 4: Testing background goroutine lifecycle...")
// Wait longer to allow background cleanup to run
time.Sleep(1 * time.Second)
// Take final snapshot after cleanup
finalAfterCleanup, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take final snapshot after cleanup: %v", err)
}
// Check if memory decreased after cleanup
if finalAfterCleanup.RuntimeStats.Alloc < current.RuntimeStats.Alloc {
memoryDecrease := current.RuntimeStats.Alloc - finalAfterCleanup.RuntimeStats.Alloc
t.Logf("Memory decreased by %.2f MB after cleanup phase", float64(memoryDecrease)/(1024*1024))
}
// Clean up resources
metadataCache.Close()
wg.Wait() // Ensure all background goroutines complete
t.Logf("Test completed: %d total requests, %d server failures, %d post-shutdown failures",
requestCount, serverFailures, postShutdownFailures)
t.Logf("Memory usage: baseline=%.2f MB, intermediate=%.2f MB, final=%.2f MB",
float64(baseline.RuntimeStats.Alloc)/(1024*1024),
float64(intermediate.RuntimeStats.Alloc)/(1024*1024),
float64(current.RuntimeStats.Alloc)/(1024*1024))
}
// TestMemoryPoolLeakDetection tests for memory leaks in memory pool operations
func TestMemoryPoolLeakDetection(t *testing.T) {
logger := NewLogger("debug")
strictMode := os.Getenv("STRICT_MEMORY_TEST") == "true"
if strictMode {
t.Log("Running in strict memory test mode - will fail on detected leaks")
} else {
t.Log("Running in lenient memory test mode - will log warnings instead of failing")
}
config := LeakDetectionConfig{
EnableLeakDetection: true,
LeakThresholdMB: 10,
}
mto := NewMemoryTestOrchestrator(config, logger)
// Create memory pool manager and token compression pool
memoryPoolManager := NewMemoryPoolManager()
tokenCompressionPool := NewTokenCompressionPool()
// Create profiler for memory pools
profiler := NewMemoryPoolProfiler(memoryPoolManager, tokenCompressionPool, logger)
mto.RegisterComponent("memory_pools", profiler)
// Take initial baseline
baseline, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take baseline snapshot: %v", err)
}
initialGoroutines := runtime.NumGoroutine()
// Phase 1: Simulate various memory pool operations
t.Log("Phase 1: Testing memory pool operations with various patterns...")
// Test compression buffer pool
for i := 0; i < 100; i++ {
buf := memoryPoolManager.GetCompressionBuffer()
// Simulate some work with the buffer
buf.WriteString(fmt.Sprintf("test data %d", i))
// Properly return buffer to pool
memoryPoolManager.PutCompressionBuffer(buf)
}
// Test JWT parsing buffer pool
for i := 0; i < 50; i++ {
jwtBuf := memoryPoolManager.GetJWTParsingBuffer()
// Simulate JWT parsing operations
jwtBuf.HeaderBuf = append(jwtBuf.HeaderBuf, []byte("header")...)
jwtBuf.PayloadBuf = append(jwtBuf.PayloadBuf, []byte("payload")...)
jwtBuf.SignatureBuf = append(jwtBuf.SignatureBuf, []byte("signature")...)
// Properly return buffer to pool
memoryPoolManager.PutJWTParsingBuffer(jwtBuf)
}
// Test HTTP response buffer pool
for i := 0; i < 75; i++ {
httpBuf := memoryPoolManager.GetHTTPResponseBuffer()
// Simulate HTTP response processing
copy(httpBuf[:min(len(httpBuf), 100)], []byte("http response data"))
// Properly return buffer to pool
memoryPoolManager.PutHTTPResponseBuffer(httpBuf)
}
// Test string builder pool
for i := 0; i < 60; i++ {
sb := memoryPoolManager.GetStringBuilder()
// Simulate string building operations
sb.WriteString(fmt.Sprintf("built string %d", i))
_ = sb.String() // Use the result
// Properly return string builder to pool
memoryPoolManager.PutStringBuilder(sb)
}
// Test token compression pool
for i := 0; i < 40; i++ {
compBuf := tokenCompressionPool.GetCompressionBuffer()
// Simulate compression operations
compBuf.WriteString(fmt.Sprintf("compress data %d", i))
// Properly return buffer to pool
tokenCompressionPool.PutCompressionBuffer(compBuf)
decompBuf := tokenCompressionPool.GetDecompressionBuffer()
// Simulate decompression operations
decompBuf.WriteString(fmt.Sprintf("decompress data %d", i))
// Properly return buffer to pool
tokenCompressionPool.PutDecompressionBuffer(decompBuf)
sb := tokenCompressionPool.GetStringBuilder()
// Simulate string operations
sb.WriteString(fmt.Sprintf("token string %d", i))
_ = sb.String()
// Properly return string builder to pool
tokenCompressionPool.PutStringBuilder(sb)
}
// Take intermediate snapshot
intermediate, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take intermediate snapshot: %v", err)
}
// Phase 2: Continue with more intensive operations to test sustained usage
t.Log("Phase 2: Testing sustained memory pool usage...")
// Simulate mixed operations with varying patterns
for i := 0; i < 200; i++ {
// Mix different pool operations
switch i % 4 {
case 0:
buf := memoryPoolManager.GetCompressionBuffer()
buf.WriteString("mixed operation data")
memoryPoolManager.PutCompressionBuffer(buf)
case 1:
jwtBuf := memoryPoolManager.GetJWTParsingBuffer()
jwtBuf.HeaderBuf = append(jwtBuf.HeaderBuf, []byte("mixed")...)
memoryPoolManager.PutJWTParsingBuffer(jwtBuf)
case 2:
httpBuf := memoryPoolManager.GetHTTPResponseBuffer()
copy(httpBuf[:min(len(httpBuf), 50)], []byte("mixed http"))
memoryPoolManager.PutHTTPResponseBuffer(httpBuf)
case 3:
sb := memoryPoolManager.GetStringBuilder()
sb.WriteString("mixed string building")
_ = sb.String()
memoryPoolManager.PutStringBuilder(sb)
}
}
// Take final snapshot
current, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take current snapshot: %v", err)
}
finalGoroutines := runtime.NumGoroutine()
// Analyze for leaks
analysis := profiler.AnalyzeLeaks(baseline, current)
// Assertions for memory leaks
if analysis.HasLeak {
if strictMode {
t.Errorf("Memory leak detected in memory pool operations: %s", analysis.LeakDescription)
for _, leak := range analysis.SuspectedLeaks {
t.Errorf("Suspected leak: %s", leak)
}
} else {
t.Logf("Memory leak warning in memory pool operations: %s", analysis.LeakDescription)
for _, leak := range analysis.SuspectedLeaks {
t.Logf("Suspected leak: %s", leak)
}
}
for _, rec := range analysis.Recommendations {
t.Logf("Recommendation: %s", rec)
}
}
// Check total memory growth
totalMemoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if totalMemoryIncrease > 15*1024*1024 { // 15MB threshold for entire test
if strictMode {
t.Errorf("Total memory usage increased by %.2f MB during memory pool operations", float64(totalMemoryIncrease)/(1024*1024))
} else {
t.Logf("Total memory usage increased by %.2f MB during memory pool operations", float64(totalMemoryIncrease)/(1024*1024))
}
}
// Check for gradual memory growth patterns
intermediateMemoryIncrease := intermediate.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
if intermediateMemoryIncrease > 8*1024*1024 { // 8MB threshold for first phase
if strictMode {
t.Errorf("Memory usage increased by %.2f MB during first phase of memory pool operations", float64(intermediateMemoryIncrease)/(1024*1024))
} else {
t.Logf("Memory usage increased by %.2f MB during first phase of memory pool operations", float64(intermediateMemoryIncrease)/(1024*1024))
}
}
// Check goroutine count stability
goroutineIncrease := finalGoroutines - initialGoroutines
if goroutineIncrease > 3 { // Allow small variance for test environment
if strictMode {
t.Errorf("Goroutine count increased by %d during memory pool operations (initial: %d, final: %d)",
goroutineIncrease, initialGoroutines, finalGoroutines)
} else {
t.Logf("Goroutine count increased by %d during memory pool operations (initial: %d, final: %d)",
goroutineIncrease, initialGoroutines, finalGoroutines)
}
}
// Phase 3: Test cleanup verification
t.Log("Phase 3: Testing cleanup verification...")
// Force garbage collection to see if pools are properly managed
runtime.GC()
runtime.GC() // Run twice to ensure cleanup
time.Sleep(10 * time.Millisecond) // Allow cleanup to complete
// Take post-cleanup snapshot
postCleanup, err := profiler.TakeSnapshot()
if err != nil {
t.Fatalf("Failed to take post-cleanup snapshot: %v", err)
}
// Check if memory decreased after cleanup
if postCleanup.RuntimeStats.Alloc < current.RuntimeStats.Alloc {
memoryDecrease := current.RuntimeStats.Alloc - postCleanup.RuntimeStats.Alloc
t.Logf("Memory decreased by %.2f MB after cleanup phase", float64(memoryDecrease)/(1024*1024))
} else if postCleanup.RuntimeStats.Alloc > current.RuntimeStats.Alloc {
memoryIncrease := postCleanup.RuntimeStats.Alloc - current.RuntimeStats.Alloc
if strictMode {
t.Errorf("Memory increased by %.2f MB after cleanup phase - possible cleanup issues", float64(memoryIncrease)/(1024*1024))
} else {
t.Logf("Memory increased by %.2f MB after cleanup phase - possible cleanup issues", float64(memoryIncrease)/(1024*1024))
}
}
t.Logf("Memory pool leak detection test completed")
t.Logf("Memory usage: baseline=%.2f MB, intermediate=%.2f MB, final=%.2f MB, post-cleanup=%.2f MB",
float64(baseline.RuntimeStats.Alloc)/(1024*1024),
float64(intermediate.RuntimeStats.Alloc)/(1024*1024),
float64(current.RuntimeStats.Alloc)/(1024*1024),
float64(postCleanup.RuntimeStats.Alloc)/(1024*1024))
}
+6
View File
@@ -74,6 +74,9 @@ func TestRefreshGracePeriodConfiguration(t *testing.T) {
}
func TestTokenRefreshWithinGracePeriod(t *testing.T) {
// Reset global state to prevent test interference
resetGlobalState()
refreshCount := int32(0)
tokenVersion := int32(1)
@@ -286,6 +289,9 @@ func TestGracePeriodWithProviderSpecificBehavior(t *testing.T) {
}
func TestRefreshGracePeriodConcurrency(t *testing.T) {
// Reset global state to prevent test interference
resetGlobalState()
var refreshMutex sync.Mutex
refreshCount := 0
blockedRequests := int32(0)
+6
View File
@@ -40,6 +40,7 @@ func TestConcurrentTokenVerification(t *testing.T) {
}
// Create a fresh instance for this test
var goroutineWG sync.WaitGroup
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
@@ -51,6 +52,7 @@ func TestConcurrentTokenVerification(t *testing.T) {
allowedUserDomains: map[string]struct{}{"example.com": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
goroutineWG: &goroutineWG,
}
tOidc.tokenVerifier = tOidc
tOidc.jwtVerifier = tOidc
@@ -497,6 +499,7 @@ func TestMaliciousInputValidation(t *testing.T) {
for _, test := range maliciousInputs {
t.Run(test.name, func(t *testing.T) {
// Create a fresh instance for each test to avoid rate limiting issues
var goroutineWG sync.WaitGroup
freshOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
@@ -508,6 +511,7 @@ func TestMaliciousInputValidation(t *testing.T) {
allowedUserDomains: map[string]struct{}{"example.com": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
goroutineWG: &goroutineWG,
}
freshOidc.tokenVerifier = freshOidc
freshOidc.jwtVerifier = freshOidc
@@ -731,6 +735,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
}
// Create fresh instance with high rate limit
var goroutineWG sync.WaitGroup
tOidc := &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
@@ -742,6 +747,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
allowedUserDomains: map[string]struct{}{"example.com": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
goroutineWG: &goroutineWG,
}
tOidc.tokenVerifier = tOidc
tOidc.jwtVerifier = tOidc
+14 -3
View File
@@ -245,7 +245,8 @@ type SessionManager struct {
forceHTTPS bool
cookieDomain string
chunkManager *ChunkManager
cleanupDone bool // Track if we've attempted cookie cleanup
cleanupMutex sync.RWMutex // Protects cleanup operations
cleanupDone bool // Track if we've attempted cookie cleanup
}
// NewSessionManager creates a new session manager with the specified configuration.
@@ -628,7 +629,11 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
// If we have a configured domain and this is the first request after config change,
// attempt to delete the cookie with various domain variations to ensure cleanup
if currentDomain != "" && !sm.cleanupDone {
sm.cleanupMutex.RLock()
shouldCleanup := currentDomain != "" && !sm.cleanupDone
sm.cleanupMutex.RUnlock()
if shouldCleanup {
for _, domain := range domainsToClean {
// Skip the current configured domain
if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain {
@@ -655,7 +660,9 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
// Mark cleanup as done for this session manager instance
if !sm.cleanupDone && len(processedCookies) > 0 {
sm.cleanupMutex.Lock()
sm.cleanupDone = true
sm.cleanupMutex.Unlock()
}
}
@@ -671,6 +678,8 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
sessionData.dirty = false
var sessionReturned bool
// CRITICAL FIX: Defer logic to return session to pool only on panic, not on successful return
// This prevents memory leaks by ensuring sessions are only pooled when not in use by caller
defer func() {
if !sessionReturned && sessionData != nil {
if r := recover(); r != nil {
@@ -736,7 +745,9 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
sm.getTokenChunkSessions(r, idTokenCookie, sessionData.idTokenChunks)
sessionReturned = false
// CRITICAL FIX: Mark session as successfully returned to caller - prevents premature pool return
// This fixes memory leak where sessions were incorrectly returned to pool immediately after success
sessionReturned = true
return sessionData, nil
}
+37 -1
View File
@@ -6,6 +6,7 @@ import (
"log"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -25,6 +26,33 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
// Test helper adapters for the new test files
// resetGlobalState resets all global singletons to prevent test interference
func resetGlobalState() {
// Reset global cache manager
cacheManagerMutex.Lock()
if globalCacheManager != nil {
globalCacheManager.Close()
globalCacheManager = nil
}
cacheManagerOnce = sync.Once{}
cacheManagerMutex.Unlock()
// Reset replay cache
replayCacheMu.Lock()
if replayCache != nil {
replayCache.Close()
replayCache = nil
}
replayCacheOnce = sync.Once{}
replayCacheMu.Unlock()
// Reset memory pools
memoryPoolMutex.Lock()
globalMemoryPools = nil
memoryPoolOnce = sync.Once{}
memoryPoolMutex.Unlock()
}
// createTestConfig creates a config with all required fields populated for testing
func createTestConfig() *Config {
config := CreateConfig()
@@ -38,6 +66,9 @@ func createTestConfig() *Config {
// setupTestOIDCMiddleware creates a test OIDC middleware instance with mock servers
func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httptest.Server) {
// Reset global state to ensure test isolation
resetGlobalState()
// Create mock OIDC server
var serverURL string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -114,6 +145,9 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
testTokenURL := testIssuerURL + "/token"
testJWKSURL := testIssuerURL + "/keys"
// Create WaitGroup for background goroutines
var wg sync.WaitGroup
// Create TraefikOidc instance directly
oidc := &TraefikOidc{
next: nextHandler,
@@ -143,8 +177,10 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
forceHTTPS: config.ForceHTTPS,
allowedUserDomains: make(map[string]struct{}),
jwkCache: &JWKCache{},
metadataCache: NewMetadataCache(),
metadataCache: NewMetadataCache(nil),
ctx: context.Background(),
goroutineWG: &wg,
providerURL: serverURL,
}
// Process excluded URLs
+1 -1
View File
@@ -14,4 +14,4 @@ func generateRandomString(length int) string {
return "random-string-fallback"
}
return hex.EncodeToString(bytes)
}
}