mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dc16b4fc94 | |||
| 7db2f8d66c | |||
| 024e349aa9 | |||
| 873105108e |
+24
-2
@@ -1,6 +1,9 @@
|
||||
package traefikoidc
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BackgroundTask represents a managed recurring task that runs in the background.
|
||||
// It provides a clean interface for starting and stopping periodic operations
|
||||
@@ -11,6 +14,7 @@ type BackgroundTask struct {
|
||||
logger *Logger
|
||||
name string
|
||||
interval time.Duration
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task with the specified parameters.
|
||||
@@ -20,35 +24,53 @@ type BackgroundTask struct {
|
||||
// - interval: Duration between task executions.
|
||||
// - taskFunc: The function to execute periodically.
|
||||
// - logger: Logger instance for task lifecycle events.
|
||||
// - wg: Optional WaitGroup for synchronizing goroutine completion.
|
||||
//
|
||||
// Returns:
|
||||
// - A configured BackgroundTask ready to be started.
|
||||
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger) *BackgroundTask {
|
||||
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger, wg ...*sync.WaitGroup) *BackgroundTask {
|
||||
var waitGroup *sync.WaitGroup
|
||||
if len(wg) > 0 {
|
||||
waitGroup = wg[0]
|
||||
}
|
||||
return &BackgroundTask{
|
||||
name: name,
|
||||
interval: interval,
|
||||
stopChan: make(chan struct{}),
|
||||
taskFunc: taskFunc,
|
||||
logger: logger,
|
||||
wg: waitGroup,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the background task execution in a separate goroutine.
|
||||
// The task runs immediately upon start and then at the specified interval.
|
||||
func (bt *BackgroundTask) Start() {
|
||||
if bt.wg != nil {
|
||||
bt.wg.Add(1)
|
||||
}
|
||||
go bt.run()
|
||||
}
|
||||
|
||||
// Stop gracefully terminates the background task by closing the stop channel.
|
||||
// If a WaitGroup was provided, it waits for the goroutine to complete.
|
||||
// This method is safe to call multiple times.
|
||||
func (bt *BackgroundTask) Stop() {
|
||||
close(bt.stopChan)
|
||||
if bt.wg != nil {
|
||||
bt.wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// run is the main execution loop for the background task.
|
||||
// It executes the task function immediately and then at regular intervals
|
||||
// until the stop signal is received.
|
||||
func (bt *BackgroundTask) run() {
|
||||
defer func() {
|
||||
if bt.wg != nil {
|
||||
bt.wg.Done()
|
||||
}
|
||||
}()
|
||||
ticker := time.NewTicker(bt.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewCacheWithLogger(logger *Logger) *Cache {
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
autoCleanupInterval: 15 * time.Minute, // Increased from 5 minutes to reduce overhead
|
||||
logger: logger,
|
||||
}
|
||||
c.startAutoCleanup()
|
||||
@@ -155,16 +155,21 @@ func (c *Cache) Cleanup() {
|
||||
}
|
||||
|
||||
// evictOldest removes the least recently used (oldest) item from the cache.
|
||||
// It first attempts to find and remove an expired item from the front of the LRU list.
|
||||
// If no expired items are found at the front, it removes the absolute oldest item (front of the list).
|
||||
// It first attempts to find and remove expired items, checking up to 5 items
|
||||
// from the front of the LRU list for efficiency. If no expired items are found,
|
||||
// it removes the absolute oldest item (front of the list).
|
||||
// This method is called internally by Set when the cache reaches its maximum size.
|
||||
// Note: This function assumes the write lock is already held.
|
||||
func (c *Cache) evictOldest() {
|
||||
now := time.Now()
|
||||
elem := c.order.Front()
|
||||
|
||||
// First try to find an expired item from the front
|
||||
for elem != nil {
|
||||
// Check up to 5 items from the front for expired entries
|
||||
// This limits the search overhead while still finding expired items efficiently
|
||||
const maxExpiredCheck = 5
|
||||
checked := 0
|
||||
|
||||
for elem != nil && checked < maxExpiredCheck {
|
||||
entry := elem.Value.(lruEntry)
|
||||
if item, exists := c.items[entry.key]; exists {
|
||||
if now.After(item.ExpiresAt) {
|
||||
@@ -173,9 +178,10 @@ func (c *Cache) evictOldest() {
|
||||
}
|
||||
}
|
||||
elem = elem.Next()
|
||||
checked++
|
||||
}
|
||||
|
||||
// If no expired items found, remove the oldest item
|
||||
// If no expired items found in the first few entries, remove the oldest item
|
||||
if elem = c.order.Front(); elem != nil {
|
||||
entry := elem.Value.(lruEntry)
|
||||
c.removeItem(entry.key)
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCacheMemoryLeaks tests various cache scenarios for memory leaks
|
||||
func TestCacheMemoryLeaks(t *testing.T) {
|
||||
t.Run("Cache doesn't release expired items memory", func(t *testing.T) {
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
cache := NewCache()
|
||||
defer cache.Close()
|
||||
|
||||
// Add many large items with short expiration
|
||||
largeData := make([]byte, 1024*1024) // 1MB
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
// Items expire in 1 second
|
||||
cache.Set(key, largeData, 1*time.Second)
|
||||
}
|
||||
|
||||
// Wait for items to expire
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Force cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Check memory after cleanup
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
afterCleanupAlloc := m.Alloc
|
||||
|
||||
allocIncrease := float64(afterCleanupAlloc-baselineAlloc) / 1024 / 1024
|
||||
t.Logf("Memory after adding and expiring 100MB of data: %.2f MB", allocIncrease)
|
||||
|
||||
// Memory should be mostly released after cleanup
|
||||
if allocIncrease > 10.0 {
|
||||
t.Errorf("Cache retains too much memory after cleanup: %.2f MB", allocIncrease)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token blacklist unbounded growth", func(t *testing.T) {
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
blacklist := NewCache()
|
||||
blacklist.SetMaxSize(1000) // Limit size
|
||||
defer blacklist.Close()
|
||||
|
||||
// Simulate continuous token blacklisting
|
||||
for i := 0; i < 10000; i++ {
|
||||
token := fmt.Sprintf("token-%d", i)
|
||||
// All tokens expire in 24 hours (typical blacklist duration)
|
||||
blacklist.Set(token, true, 24*time.Hour)
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
t.Logf("Memory after adding 10000 blacklisted tokens (max 1000): %.2f MB", allocIncrease)
|
||||
|
||||
// Should respect max size limit
|
||||
if len(blacklist.items) > 1000 {
|
||||
t.Errorf("Blacklist exceeded max size: %d items", len(blacklist.items))
|
||||
}
|
||||
|
||||
// Memory should be bounded
|
||||
if allocIncrease > 5.0 {
|
||||
t.Errorf("Blacklist uses too much memory: %.2f MB for max 1000 items", allocIncrease)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Replay cache with high JTI volume", func(t *testing.T) {
|
||||
initReplayCache()
|
||||
defer cleanupReplayCache()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Simulate high volume of JTIs
|
||||
for i := 0; i < 20000; i++ {
|
||||
jti := fmt.Sprintf("jti-%d", i)
|
||||
replayCacheMu.Lock()
|
||||
if replayCache != nil {
|
||||
// JTIs expire after token expiry (typically 1 hour)
|
||||
replayCache.Set(jti, true, 1*time.Hour)
|
||||
}
|
||||
replayCacheMu.Unlock()
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
t.Logf("Memory after adding 20000 JTIs (max 10000): %.2f MB", allocIncrease)
|
||||
|
||||
// Check size limit is enforced
|
||||
replayCacheMu.RLock()
|
||||
cacheSize := 0
|
||||
if replayCache != nil {
|
||||
cacheSize = len(replayCache.items)
|
||||
}
|
||||
replayCacheMu.RUnlock()
|
||||
|
||||
if cacheSize > 10000 {
|
||||
t.Errorf("Replay cache exceeded max size: %d items", cacheSize)
|
||||
}
|
||||
|
||||
// Memory should be bounded
|
||||
if allocIncrease > 10.0 {
|
||||
t.Errorf("Replay cache uses too much memory: %.2f MB for max 10000 items", allocIncrease)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Cache cleanup interval effectiveness", func(t *testing.T) {
|
||||
// Create a cache with custom settings - don't use NewCache to avoid default cleanup
|
||||
cache := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 200 * time.Millisecond, // Fast cleanup for test
|
||||
logger: newNoOpLogger(),
|
||||
}
|
||||
|
||||
// Start cleanup with our custom interval
|
||||
cache.startAutoCleanup()
|
||||
defer cache.Close()
|
||||
|
||||
// Add expired items
|
||||
for i := 0; i < 1000; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
cache.Set(key, "data", 50*time.Millisecond) // Very short expiry
|
||||
}
|
||||
|
||||
// Wait for items to expire and cleanup to run (at least 2 cleanup cycles)
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// Manually trigger cleanup to ensure it runs
|
||||
cache.Cleanup()
|
||||
|
||||
// Check that expired items are removed
|
||||
cache.mutex.RLock()
|
||||
remainingItems := len(cache.items)
|
||||
cache.mutex.RUnlock()
|
||||
|
||||
t.Logf("Remaining items after auto cleanup: %d", remainingItems)
|
||||
|
||||
if remainingItems > 100 {
|
||||
t.Errorf("Auto cleanup not effective: %d items remain", remainingItems)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Concurrent cache operations memory stability", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
defer cache.Close()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
// Writers continuously add items
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 1000; j++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
key := fmt.Sprintf("writer-%d-%d", id, j)
|
||||
cache.Set(key, "data", 1*time.Second)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Readers continuously read items
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 1000; j++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
key := fmt.Sprintf("writer-%d-%d", id%5, j)
|
||||
cache.Get(key)
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Let it run for a bit
|
||||
time.Sleep(5 * time.Second)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
|
||||
// Handle potential underflow
|
||||
var allocIncrease float64
|
||||
if finalAlloc > baselineAlloc {
|
||||
allocIncrease = float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
} else {
|
||||
allocIncrease = -float64(baselineAlloc-finalAlloc) / 1024 / 1024
|
||||
}
|
||||
t.Logf("Memory increase under concurrent load: %.2f MB", allocIncrease)
|
||||
|
||||
if allocIncrease > 5.0 {
|
||||
t.Errorf("Memory leak under concurrent operations: %.2f MB", allocIncrease)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LRU eviction memory release", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(100) // Small cache
|
||||
defer cache.Close()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Add many items to trigger eviction
|
||||
for i := 0; i < 1000; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
data := make([]byte, 10240) // 10KB per item
|
||||
cache.Set(key, data, 1*time.Hour)
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
afterEvictionAlloc := m.Alloc
|
||||
|
||||
allocIncrease := float64(afterEvictionAlloc-baselineAlloc) / 1024 / 1024
|
||||
t.Logf("Memory after LRU eviction (1000 items, max 100): %.2f MB", allocIncrease)
|
||||
|
||||
// Should only keep 100 items worth of memory
|
||||
if allocIncrease > 2.0 { // 100 * 10KB = ~1MB
|
||||
t.Errorf("LRU eviction doesn't release memory properly: %.2f MB", allocIncrease)
|
||||
}
|
||||
|
||||
// Verify cache size
|
||||
if len(cache.items) > 100 {
|
||||
t.Errorf("Cache size exceeded limit: %d items", len(cache.items))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Token cache with claims memory", func(t *testing.T) {
|
||||
tokenCache := NewTokenCache()
|
||||
defer tokenCache.Close()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Add tokens with large claims
|
||||
for i := 0; i < 1000; i++ {
|
||||
token := fmt.Sprintf("token-%d", i)
|
||||
claims := map[string]interface{}{
|
||||
"sub": fmt.Sprintf("user-%d", i),
|
||||
"email": fmt.Sprintf("user%d@example.com", i),
|
||||
"groups": make([]string, 100), // Large groups list
|
||||
"data": make([]byte, 1024), // Extra data
|
||||
}
|
||||
tokenCache.Set(token, claims, 1*time.Hour)
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
t.Logf("Memory after adding 1000 tokens with large claims: %.2f MB", allocIncrease)
|
||||
|
||||
// Check if memory is reasonable
|
||||
if allocIncrease > 20.0 {
|
||||
t.Errorf("Token cache uses excessive memory: %.2f MB", allocIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCacheEvictionBug tests the inefficient eviction in evictOldest
|
||||
func TestCacheEvictionBug(t *testing.T) {
|
||||
t.Run("evictOldest scans entire list", func(t *testing.T) {
|
||||
cache := NewCache()
|
||||
cache.SetMaxSize(100)
|
||||
defer cache.Close()
|
||||
|
||||
// Fill cache with non-expired items
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
cache.Set(key, "data", 1*time.Hour) // Long expiry
|
||||
}
|
||||
|
||||
// Try to add one more item to trigger eviction
|
||||
start := time.Now()
|
||||
cache.Set("trigger", "data", 1*time.Hour)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
t.Logf("Time to evict and add one item: %v", elapsed)
|
||||
|
||||
// Should be fast even with full cache
|
||||
if elapsed > 10*time.Millisecond {
|
||||
t.Errorf("Eviction too slow, possibly scanning entire list: %v", elapsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,646 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestBackgroundGoroutineLeaks tests that background goroutines don't leak memory
|
||||
// even when no requests are made to protected resources
|
||||
func TestBackgroundGoroutineLeaks(t *testing.T) {
|
||||
t.Run("Idle middleware memory growth", func(t *testing.T) {
|
||||
// Force GC to get clean baseline
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
baselineGoroutines := runtime.NumGoroutine()
|
||||
|
||||
t.Logf("Baseline: Memory=%d KB, Goroutines=%d", baselineAlloc/1024, baselineGoroutines)
|
||||
|
||||
// Create middleware instance
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://example.com"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.SessionEncryptionKey = "test-encryption-key-that-is-long-enough-32bytes"
|
||||
config.LogLevel = "error" // Reduce noise
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler, err := New(ctx, next, config, "test-middleware")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Cast to TraefikOidc to access internals
|
||||
middleware, ok := handler.(*TraefikOidc)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast to TraefikOidc")
|
||||
}
|
||||
|
||||
// Let it sit idle for a while - simulating no requests
|
||||
// During this time, background goroutines are running
|
||||
t.Log("Letting middleware sit idle for 30 seconds...")
|
||||
|
||||
// Take measurements every 5 seconds
|
||||
for i := 0; i < 6; i++ {
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
goroutineIncrease := currentGoroutines - baselineGoroutines
|
||||
|
||||
t.Logf("After %d seconds: Memory increase=%.2f MB, Goroutine increase=%d",
|
||||
(i+1)*5, allocIncrease, goroutineIncrease)
|
||||
|
||||
// Check for significant memory growth (more than 5MB)
|
||||
if allocIncrease > 5.0 {
|
||||
t.Errorf("Significant memory increase detected: %.2f MB after %d seconds of idle",
|
||||
allocIncrease, (i+1)*5)
|
||||
}
|
||||
|
||||
// Check for goroutine leaks (more than 10 extra goroutines)
|
||||
if goroutineIncrease > 10 {
|
||||
t.Errorf("Goroutine leak detected: %d extra goroutines after %d seconds",
|
||||
goroutineIncrease, (i+1)*5)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
err = middleware.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to close middleware: %v", err)
|
||||
}
|
||||
|
||||
// Wait for cleanup
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Final check
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
|
||||
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
finalGoroutineIncrease := finalGoroutines - baselineGoroutines
|
||||
|
||||
t.Logf("Final: Memory increase=%.2f MB, Goroutine increase=%d",
|
||||
finalAllocIncrease, finalGoroutineIncrease)
|
||||
|
||||
if finalGoroutineIncrease > 2 {
|
||||
t.Errorf("Goroutines not cleaned up properly: %d extra goroutines remain",
|
||||
finalGoroutineIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestHTTPClientConnectionLeaks tests that HTTP clients don't leak connections
|
||||
func TestHTTPClientConnectionLeaks(t *testing.T) {
|
||||
t.Run("HTTP client connection accumulation", func(t *testing.T) {
|
||||
// Create test server that simulates OIDC endpoints
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"jwks_uri": "https://example.com/jwks",
|
||||
"userinfo_endpoint": "https://example.com/userinfo"
|
||||
}`))
|
||||
case "/jwks":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"keys": []}`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Monitor connection count
|
||||
getActiveConnections := func(client *http.Client) int {
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
// This is a simplified check - in reality we'd need to inspect
|
||||
// the transport's connection pool
|
||||
return transport.MaxIdleConnsPerHost
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Create multiple HTTP clients like the middleware does
|
||||
clients := make([]*http.Client, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
clients[i] = createDefaultHTTPClient()
|
||||
|
||||
// Make requests
|
||||
resp, err := clients[i].Get(server.URL + "/.well-known/openid-configuration")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Check connection settings
|
||||
conns := getActiveConnections(clients[i])
|
||||
if conns > 1 {
|
||||
t.Logf("Client %d has %d max idle connections per host", i, conns)
|
||||
}
|
||||
}
|
||||
|
||||
// Let connections sit idle (now only 5s with our optimizations)
|
||||
time.Sleep(6 * time.Second) // Slightly longer than new IdleConnTimeout (5s)
|
||||
|
||||
// Force cleanup
|
||||
for _, client := range clients {
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("HTTP clients cleaned up successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// TestCacheBackgroundTaskLeaks tests that cache background tasks don't leak
|
||||
func TestCacheBackgroundTaskLeaks(t *testing.T) {
|
||||
t.Run("Multiple cache instances with cleanup tasks", func(t *testing.T) {
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Create many cache instances
|
||||
caches := make([]*Cache, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
caches[i] = NewCache()
|
||||
// Add some data
|
||||
for j := 0; j < 100; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", i, j)
|
||||
caches[i].Set(key, "value", 5*time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all cleanup goroutines to start
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
afterCreateGoroutines := runtime.NumGoroutine()
|
||||
goroutineIncrease := afterCreateGoroutines - initialGoroutines
|
||||
|
||||
t.Logf("Created %d caches, goroutine increase: %d", len(caches), goroutineIncrease)
|
||||
|
||||
// Expected: one cleanup goroutine per cache
|
||||
if goroutineIncrease < len(caches) {
|
||||
t.Errorf("Expected at least %d goroutines, got %d increase",
|
||||
len(caches), goroutineIncrease)
|
||||
}
|
||||
|
||||
// Close all caches
|
||||
for _, cache := range caches {
|
||||
cache.Close()
|
||||
}
|
||||
|
||||
// Wait for goroutines to stop
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
remainingGoroutines := finalGoroutines - initialGoroutines
|
||||
|
||||
if remainingGoroutines > 5 { // Allow small tolerance
|
||||
t.Errorf("Cache cleanup goroutines not stopped properly: %d extra goroutines remain",
|
||||
remainingGoroutines)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGlobalSingletonMemoryGrowth tests that global singletons don't grow unbounded
|
||||
func TestGlobalSingletonMemoryGrowth(t *testing.T) {
|
||||
t.Run("Global cache manager memory growth", func(t *testing.T) {
|
||||
// Clean up any existing global state
|
||||
CleanupGlobalCacheManager()
|
||||
CleanupGlobalMemoryPools()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Get global cache manager
|
||||
wg := &sync.WaitGroup{}
|
||||
cm := GetGlobalCacheManager(wg)
|
||||
|
||||
// Simulate continuous usage without cleanup
|
||||
for i := 0; i < 1000; i++ {
|
||||
// Add to token cache
|
||||
cm.GetSharedTokenCache().Set(
|
||||
fmt.Sprintf("token-%d", i),
|
||||
map[string]interface{}{"claim": fmt.Sprintf("value-%d", i)},
|
||||
5*time.Minute,
|
||||
)
|
||||
|
||||
// Add to blacklist
|
||||
cm.GetSharedTokenBlacklist().Set(
|
||||
fmt.Sprintf("blacklist-%d", i),
|
||||
true,
|
||||
5*time.Minute,
|
||||
)
|
||||
|
||||
// Every 100 iterations, check memory
|
||||
if i%100 == 0 {
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("After %d items: Memory increase=%.2f MB", i, allocIncrease)
|
||||
|
||||
// The caches should have max size limits
|
||||
// If memory grows more than 10MB, there's likely a leak
|
||||
if allocIncrease > 10.0 {
|
||||
t.Errorf("Excessive memory growth in global caches: %.2f MB after %d items",
|
||||
allocIncrease, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
CleanupGlobalCacheManager()
|
||||
CleanupGlobalMemoryPools()
|
||||
wg.Wait()
|
||||
|
||||
// Force full cleanup of replay cache too
|
||||
cleanupReplayCache()
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("Final memory increase after cleanup: %.2f MB", finalAllocIncrease)
|
||||
|
||||
if finalAllocIncrease > 2.0 {
|
||||
t.Errorf("Memory not properly released after cleanup: %.2f MB remains",
|
||||
finalAllocIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMetadataCacheRefreshLeak tests for memory leaks in metadata refresh
|
||||
func TestMetadataCacheRefreshLeak(t *testing.T) {
|
||||
t.Run("Metadata cache refresh memory leak", func(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
cache := NewMetadataCacheWithLogger(wg, NewLogger("error"))
|
||||
defer cache.Close()
|
||||
|
||||
// Mock HTTP client that returns metadata
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return a large metadata response to amplify any leaks
|
||||
w.Write([]byte(`{
|
||||
"issuer": "https://example.com",
|
||||
"authorization_endpoint": "https://example.com/auth",
|
||||
"token_endpoint": "https://example.com/token",
|
||||
"jwks_uri": "https://example.com/jwks",
|
||||
"userinfo_endpoint": "https://example.com/userinfo",
|
||||
"extra_field_1": "` + strings.Repeat("x", 1000) + `",
|
||||
"extra_field_2": "` + strings.Repeat("y", 1000) + `"
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Simulate many metadata refreshes
|
||||
for i := 0; i < 100; i++ {
|
||||
_, err := cache.GetMetadata(server.URL, client, NewLogger("error"))
|
||||
if err != nil {
|
||||
t.Logf("Metadata fetch error (expected for test server): %v", err)
|
||||
}
|
||||
|
||||
// Force cache expiry to trigger refresh
|
||||
cache.mutex.Lock()
|
||||
cache.expiresAt = time.Now().Add(-1 * time.Hour)
|
||||
cache.mutex.Unlock()
|
||||
|
||||
if i%20 == 0 && i > 0 {
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("After %d refreshes: Memory increase=%.2f MB", i, allocIncrease)
|
||||
|
||||
// Metadata cache should only store one copy
|
||||
if allocIncrease > 3.0 {
|
||||
t.Errorf("Metadata cache leak detected: %.2f MB after %d refreshes",
|
||||
allocIncrease, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache.Close()
|
||||
wg.Wait()
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("Final memory after metadata cache closure: %.2f MB", finalAllocIncrease)
|
||||
|
||||
if finalAllocIncrease > 1.0 {
|
||||
t.Errorf("Metadata cache not cleaned properly: %.2f MB remains", finalAllocIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryPoolLeak tests for leaks in memory pool management
|
||||
func TestMemoryPoolLeak(t *testing.T) {
|
||||
t.Run("Memory pool buffer leaks", func(t *testing.T) {
|
||||
// Clean up any existing pools
|
||||
CleanupGlobalMemoryPools()
|
||||
|
||||
pools := GetGlobalMemoryPools()
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
|
||||
// Simulate heavy buffer usage
|
||||
var buffers [][]byte
|
||||
for i := 0; i < 1000; i++ {
|
||||
buf := pools.GetHTTPResponseBuffer()
|
||||
// Simulate using the buffer
|
||||
copy(buf, []byte("test data"))
|
||||
|
||||
// 90% of the time, return the buffer
|
||||
// 10% of the time, "forget" to return it (simulating a leak)
|
||||
if i%10 != 0 {
|
||||
pools.PutHTTPResponseBuffer(buf)
|
||||
} else {
|
||||
// Keep reference to simulate leak
|
||||
buffers = append(buffers, buf)
|
||||
}
|
||||
|
||||
if i%100 == 0 && i > 0 {
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("After %d buffer operations: Memory increase=%.2f MB, Leaked buffers=%d",
|
||||
i, allocIncrease, len(buffers))
|
||||
|
||||
// With proper pooling, memory should be bounded
|
||||
if allocIncrease > 5.0 {
|
||||
t.Errorf("Memory pool leak detected: %.2f MB after %d operations",
|
||||
allocIncrease, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return the "leaked" buffers
|
||||
for _, buf := range buffers {
|
||||
pools.PutHTTPResponseBuffer(buf)
|
||||
}
|
||||
|
||||
CleanupGlobalMemoryPools()
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
|
||||
t.Logf("Final memory after pool cleanup: %.2f MB", finalAllocIncrease)
|
||||
|
||||
if finalAllocIncrease > 1.0 {
|
||||
t.Errorf("Memory pools not cleaned properly: %.2f MB remains", finalAllocIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentMemoryLeaks tests for memory leaks under concurrent load
|
||||
func TestConcurrentMemoryLeaks(t *testing.T) {
|
||||
t.Run("Concurrent operations memory stability", func(t *testing.T) {
|
||||
// Reset global state
|
||||
resetGlobalState()
|
||||
|
||||
// Set lower GC percentage to trigger GC more frequently
|
||||
debug.SetGCPercent(50)
|
||||
defer debug.SetGCPercent(100)
|
||||
|
||||
runtime.GC()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
baselineAlloc := m.Alloc
|
||||
baselineGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Create a mock OIDC server
|
||||
var mockServerURL string
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": mockServerURL,
|
||||
"authorization_endpoint": mockServerURL + "/auth",
|
||||
"token_endpoint": mockServerURL + "/token",
|
||||
"userinfo_endpoint": mockServerURL + "/userinfo",
|
||||
"jwks_uri": mockServerURL + "/keys",
|
||||
})
|
||||
case "/keys":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"kid": "test-key",
|
||||
"n": "test-modulus",
|
||||
"e": "AQAB",
|
||||
},
|
||||
},
|
||||
})
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
mockServerURL = mockServer.URL
|
||||
|
||||
// Create middleware config with mock server
|
||||
config := createTestConfig()
|
||||
config.ProviderURL = mockServerURL
|
||||
config.LogLevel = "error"
|
||||
|
||||
// Create middleware directly without using New() to avoid automatic metadata fetch
|
||||
middleware := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }),
|
||||
providerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
redirURLPath: "/callback",
|
||||
scopes: []string{"openid", "email", "profile"},
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: make(map[string]struct{}),
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
sessionManager: nil, // Will be created below
|
||||
tokenCache: NewTokenCache(),
|
||||
tokenBlacklist: NewCache(),
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
firstRequestReceived: false,
|
||||
}
|
||||
|
||||
// Create session manager
|
||||
var err error
|
||||
middleware.sessionManager, err = NewSessionManager(
|
||||
config.SessionEncryptionKey,
|
||||
config.ForceHTTPS,
|
||||
config.CookieDomain,
|
||||
middleware.logger,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Simulate concurrent operations
|
||||
var wg sync.WaitGroup
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Worker that continuously performs operations
|
||||
worker := func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
// Simulate various operations
|
||||
switch id % 4 {
|
||||
case 0:
|
||||
// Cache operations
|
||||
cache := NewCache()
|
||||
cache.Set(fmt.Sprintf("key-%d", id), "value", time.Minute)
|
||||
cache.Get(fmt.Sprintf("key-%d", id))
|
||||
cache.Close()
|
||||
case 1:
|
||||
// Session operations
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
session, _ := middleware.sessionManager.GetSession(req)
|
||||
if session != nil {
|
||||
session.SetAccessToken("token")
|
||||
session.returnToPoolSafely()
|
||||
}
|
||||
case 2:
|
||||
// Token cache operations
|
||||
cm := GetGlobalCacheManager(nil)
|
||||
if cm != nil {
|
||||
tc := cm.GetSharedTokenCache()
|
||||
tc.Set("test-token", map[string]interface{}{"test": "data"}, time.Minute)
|
||||
tc.Get("test-token")
|
||||
}
|
||||
case 3:
|
||||
// Memory pool operations
|
||||
pools := GetGlobalMemoryPools()
|
||||
buf := pools.GetHTTPResponseBuffer()
|
||||
pools.PutHTTPResponseBuffer(buf)
|
||||
}
|
||||
|
||||
// Small delay to prevent tight loop
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start workers
|
||||
numWorkers := 20
|
||||
wg.Add(numWorkers)
|
||||
for i := 0; i < numWorkers; i++ {
|
||||
go worker(i)
|
||||
}
|
||||
|
||||
// Let it run and measure periodically
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
currentAlloc := m.Alloc
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
|
||||
allocIncrease := float64(currentAlloc-baselineAlloc) / 1024 / 1024
|
||||
goroutineIncrease := currentGoroutines - baselineGoroutines
|
||||
|
||||
t.Logf("After %d seconds of concurrent load: Memory increase=%.2f MB, Goroutine increase=%d",
|
||||
(i+1)*5, allocIncrease, goroutineIncrease)
|
||||
|
||||
// Under sustained concurrent load, memory should stabilize
|
||||
if i > 2 && allocIncrease > 20.0 {
|
||||
t.Errorf("Memory leak under concurrent load: %.2f MB after %d seconds",
|
||||
allocIncrease, (i+1)*5)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop workers
|
||||
close(stopChan)
|
||||
wg.Wait()
|
||||
|
||||
// Clean up
|
||||
middleware.Close()
|
||||
|
||||
// Final measurements
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
finalAlloc := m.Alloc
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
|
||||
finalAllocIncrease := float64(finalAlloc-baselineAlloc) / 1024 / 1024
|
||||
finalGoroutineIncrease := finalGoroutines - baselineGoroutines
|
||||
|
||||
t.Logf("Final after cleanup: Memory increase=%.2f MB, Goroutine increase=%d",
|
||||
finalAllocIncrease, finalGoroutineIncrease)
|
||||
|
||||
if finalGoroutineIncrease > 5 {
|
||||
t.Errorf("Goroutines not cleaned up after concurrent operations: %d extra remain",
|
||||
finalGoroutineIncrease)
|
||||
}
|
||||
|
||||
if finalAllocIncrease > 5.0 {
|
||||
t.Errorf("Memory not released after concurrent operations: %.2f MB remains",
|
||||
finalAllocIncrease)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -64,15 +64,37 @@ func (r *ProviderRegistry) ClearCache() {
|
||||
// DetectProvider determines the most appropriate provider for a given issuer URL.
|
||||
// It iterates through the registered providers and returns the first one that matches.
|
||||
// Detection is based on URL patterns and other provider-specific criteria.
|
||||
// Uses double-checked locking pattern to avoid race conditions while caching results.
|
||||
func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
|
||||
// First check: read lock for cache lookup
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
if provider, found := r.cache[issuerURL]; found {
|
||||
r.mu.RUnlock()
|
||||
return provider
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
// Check cache first for performance
|
||||
// Cache miss - acquire write lock for detection and caching
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Second check: another goroutine might have cached the result while we waited for write lock
|
||||
if provider, found := r.cache[issuerURL]; found {
|
||||
return provider
|
||||
}
|
||||
|
||||
// Perform detection under write lock
|
||||
detectedProvider := r.detectProviderUnsafe(issuerURL)
|
||||
|
||||
// Cache the result (even if nil to avoid repeated expensive operations)
|
||||
r.cache[issuerURL] = detectedProvider
|
||||
|
||||
return detectedProvider
|
||||
}
|
||||
|
||||
// detectProviderUnsafe performs the actual provider detection logic.
|
||||
// This method assumes the caller holds the appropriate lock and should not be called directly.
|
||||
func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
|
||||
// Normalize issuer URL for consistent matching
|
||||
normalizedURL, err := url.Parse(issuerURL)
|
||||
if err != nil {
|
||||
@@ -86,12 +108,10 @@ func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
|
||||
switch p.GetType() {
|
||||
case ProviderTypeGoogle:
|
||||
if strings.Contains(host, "accounts.google.com") {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
case ProviderTypeAzure:
|
||||
if strings.Contains(host, "login.microsoftonline.com") || strings.Contains(host, "sts.windows.net") {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
}
|
||||
@@ -100,7 +120,6 @@ func (r *ProviderRegistry) DetectProvider(issuerURL string) OIDCProvider {
|
||||
// Fallback to the generic provider if no specific provider is detected
|
||||
for _, p := range r.providers {
|
||||
if p.GetType() == ProviderTypeGeneric {
|
||||
r.cache[issuerURL] = p
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,27 +36,27 @@ func createDefaultHTTPClient() *http.Client {
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second, // Connection timeout for faster failure detection
|
||||
KeepAlive: 30 * time.Second, // Keep-alive interval for connection reuse
|
||||
Timeout: 5 * time.Second, // Reduced connection timeout for faster failure detection
|
||||
KeepAlive: 15 * time.Second, // Reduced keep-alive interval for connection reuse
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 3 * time.Second, // TLS handshake timeout
|
||||
ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-continue responses
|
||||
MaxIdleConns: 10, // Reduced from 20 to prevent memory buildup
|
||||
MaxIdleConnsPerHost: 2, // Reduced from 5 to limit per-host connections
|
||||
IdleConnTimeout: 30 * time.Second, // Reduced from 60 to close idle connections faster
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 10, // Reduced from 20 to limit concurrent connections
|
||||
ResponseHeaderTimeout: 5 * time.Second, // Timeout for reading response headers
|
||||
DisableCompression: false, // Enable compression for bandwidth efficiency
|
||||
WriteBufferSize: 4096, // Write buffer size for connections
|
||||
ReadBufferSize: 4096, // Read buffer size for connections
|
||||
TLSHandshakeTimeout: 2 * time.Second, // Reduced TLS handshake timeout
|
||||
ExpectContinueTimeout: 1 * time.Second, // Timeout for 100-continue responses
|
||||
MaxIdleConns: 2, // Reduced from 5 to minimize idle connections
|
||||
MaxIdleConnsPerHost: 1, // Keep minimal idle connections per host
|
||||
IdleConnTimeout: 5 * time.Second, // Reduced from 15s to close idle connections faster
|
||||
DisableKeepAlives: false, // Enable connection reuse
|
||||
MaxConnsPerHost: 2, // Reduced from 5 to limit concurrent connections
|
||||
ResponseHeaderTimeout: 3 * time.Second, // Reduced timeout for reading response headers
|
||||
DisableCompression: false, // Enable compression for bandwidth efficiency
|
||||
WriteBufferSize: 4096, // Write buffer size for connections
|
||||
ReadBufferSize: 4096, // Read buffer size for connections
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: time.Second * 10, // HTTP client timeout
|
||||
Timeout: time.Second * 5, // Reduced HTTP client timeout to prevent hanging connections
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// Limit redirects to prevent redirect loops
|
||||
@@ -124,7 +124,7 @@ type CacheManager struct {
|
||||
// It initializes all cache types on first call with appropriate default settings.
|
||||
// This ensures thread-safe initialization and consistent cache behavior across
|
||||
// the entire application lifecycle.
|
||||
func GetGlobalCacheManager() *CacheManager {
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
cacheManagerOnce.Do(func() {
|
||||
globalCacheManager = &CacheManager{
|
||||
tokenBlacklist: func() *Cache {
|
||||
@@ -133,7 +133,7 @@ func GetGlobalCacheManager() *CacheManager {
|
||||
return c
|
||||
}(),
|
||||
tokenCache: NewTokenCache(),
|
||||
metadataCache: NewMetadataCache(),
|
||||
metadataCache: NewMetadataCache(wg),
|
||||
jwkCache: &JWKCache{},
|
||||
}
|
||||
})
|
||||
@@ -283,13 +283,17 @@ type TraefikOidc struct {
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
scopes []string
|
||||
goroutineWG sync.WaitGroup
|
||||
goroutineWG *sync.WaitGroup
|
||||
refreshGracePeriod time.Duration
|
||||
shutdownOnce sync.Once
|
||||
forceHTTPS bool
|
||||
enablePKCE bool
|
||||
overrideScopes bool
|
||||
suppressDiagnosticLogs bool
|
||||
firstRequestReceived bool
|
||||
firstRequestMutex sync.Mutex
|
||||
metadataRefreshStarted bool
|
||||
providerURL string
|
||||
}
|
||||
|
||||
// ProviderMetadata represents the OpenID Connect provider's discovery metadata.
|
||||
@@ -690,13 +694,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
} else {
|
||||
httpClient = createDefaultHTTPClient()
|
||||
}
|
||||
cacheManager := GetGlobalCacheManager()
|
||||
goroutineWG := &sync.WaitGroup{}
|
||||
cacheManager := GetGlobalCacheManager(goroutineWG)
|
||||
|
||||
pluginCtx, cancelFunc := context.WithCancel(context.Background())
|
||||
|
||||
t := &TraefikOidc{
|
||||
next: next,
|
||||
name: name,
|
||||
goroutineWG: goroutineWG,
|
||||
redirURLPath: config.CallbackURL,
|
||||
logoutURLPath: func() string {
|
||||
if config.LogoutURL == "" {
|
||||
@@ -771,7 +777,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
t.tokenVerifier = t
|
||||
t.jwtVerifier = t
|
||||
t.startTokenCleanup()
|
||||
// Delay startTokenCleanup() until first request
|
||||
t.tokenExchanger = t // Initialize the interface field to self
|
||||
|
||||
// Initialize and parse header templates with safe field access
|
||||
@@ -815,6 +821,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
startReplayCacheCleanup(pluginCtx, logger)
|
||||
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
|
||||
|
||||
// Store provider URL for later use
|
||||
t.providerURL = config.ProviderURL
|
||||
go t.initializeMetadata(config.ProviderURL)
|
||||
|
||||
return t, nil
|
||||
@@ -851,8 +860,8 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
t.logger.Debug("Successfully initialized provider metadata")
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
|
||||
// Start metadata refresh goroutine
|
||||
go t.startMetadataRefresh(providerURL)
|
||||
// Delay metadata refresh goroutine until first request
|
||||
// It will be started in ServeHTTP when needed
|
||||
|
||||
// Only close channel on success
|
||||
close(t.initComplete)
|
||||
@@ -888,11 +897,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
// Use longer interval to reduce memory pressure from refresh attempts
|
||||
ticker := time.NewTicker(2 * time.Hour) // Increased from 1 hour
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
|
||||
// Check if WaitGroup is initialized (it might be nil in tests)
|
||||
if t.goroutineWG != nil {
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
defer ticker.Stop() // Ensure ticker is always stopped
|
||||
defer func() {
|
||||
if t.goroutineWG != nil {
|
||||
t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
}
|
||||
}()
|
||||
defer ticker.Stop() // Ensure ticker is always stopped
|
||||
|
||||
consecutiveFailures := 0
|
||||
for {
|
||||
@@ -1056,6 +1073,23 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
// - Authentication state management
|
||||
// - Header injection for authenticated requests
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Start background tasks on first request (except for health checks)
|
||||
if !t.firstRequestReceived && !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
t.firstRequestReceived = true
|
||||
t.logger.Debug("Starting background tasks on first request")
|
||||
t.startTokenCleanup()
|
||||
|
||||
// Also start metadata refresh if not already started
|
||||
if !t.metadataRefreshStarted && t.providerURL != "" {
|
||||
t.metadataRefreshStarted = true
|
||||
go t.startMetadataRefresh(t.providerURL)
|
||||
}
|
||||
}
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for provider metadata initialization to complete
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
@@ -2048,11 +2082,17 @@ func (t *TraefikOidc) validateHost(host string) error {
|
||||
// panic recovery to ensure stability.
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute) // Run cleanup every minute
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
|
||||
// Check if WaitGroup is initialized (it might be nil in tests)
|
||||
if t.goroutineWG != nil {
|
||||
t.goroutineWG.Add(1) // Track this goroutine
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
ticker.Stop() // Ensure ticker is always stopped
|
||||
if t.goroutineWG != nil {
|
||||
t.goroutineWG.Done() // Signal completion when goroutine exits
|
||||
}
|
||||
ticker.Stop() // Ensure ticker is always stopped
|
||||
|
||||
// CRITICAL: Recover from panics to prevent middleware crashes
|
||||
if r := recover(); r != nil {
|
||||
@@ -2838,17 +2878,22 @@ func (t *TraefikOidc) Close() error {
|
||||
t.logger.Debug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.goroutineWG.Wait()
|
||||
close(done)
|
||||
}()
|
||||
// Only wait for goroutines if WaitGroup is initialized
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
t.goroutineWG.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.logger.Debug("All background goroutines stopped gracefully")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.logger.Errorf("Timeout waiting for background goroutines to stop")
|
||||
select {
|
||||
case <-done:
|
||||
t.logger.Debug("All background goroutines stopped gracefully")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.logger.Errorf("Timeout waiting for background goroutines to stop")
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("No goroutineWG to wait for (likely in test)")
|
||||
}
|
||||
|
||||
if t.httpClient != nil {
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LazyBackgroundTask wraps BackgroundTask to start only when needed
|
||||
type LazyBackgroundTask struct {
|
||||
*BackgroundTask
|
||||
started bool
|
||||
startOnce sync.Once
|
||||
}
|
||||
|
||||
// NewLazyBackgroundTask creates a background task that doesn't start immediately
|
||||
func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(), logger *Logger, wg ...*sync.WaitGroup) *LazyBackgroundTask {
|
||||
return &LazyBackgroundTask{
|
||||
BackgroundTask: NewBackgroundTask(name, interval, taskFunc, logger, wg...),
|
||||
started: false,
|
||||
}
|
||||
}
|
||||
|
||||
// StartIfNeeded starts the task only if it hasn't been started yet
|
||||
func (lt *LazyBackgroundTask) StartIfNeeded() {
|
||||
lt.startOnce.Do(func() {
|
||||
if !lt.started {
|
||||
lt.BackgroundTask.Start()
|
||||
lt.started = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Stop stops the task if it was started
|
||||
func (lt *LazyBackgroundTask) Stop() {
|
||||
if lt.started {
|
||||
lt.BackgroundTask.Stop()
|
||||
lt.started = false
|
||||
lt.startOnce = sync.Once{} // Reset for potential restart
|
||||
}
|
||||
}
|
||||
|
||||
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use
|
||||
func NewLazyCacheWithLogger(logger *Logger) *Cache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &Cache{
|
||||
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||
order: list.New(),
|
||||
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||
maxSize: DefaultMaxSize,
|
||||
autoCleanupInterval: 10 * time.Minute, // Increased from 5 minutes
|
||||
logger: logger,
|
||||
}
|
||||
// Don't start cleanup immediately - it will be started on first use
|
||||
return c
|
||||
}
|
||||
|
||||
// NewLazyCache creates a cache that doesn't start cleanup immediately
|
||||
func NewLazyCache() *Cache {
|
||||
return NewLazyCacheWithLogger(nil)
|
||||
}
|
||||
|
||||
// CleanupIdleConnections periodically closes idle HTTP connections held by
// client's transport. It blocks until stopChan is closed (or receives a
// value), so it is meant to be run in its own goroutine.
//
// Parameters:
//   - client: the HTTP client whose transport is cleaned. Only transports of
//     type *http.Transport are handled; any other transport is left alone.
//     A nil client makes the function return immediately.
//   - interval: time between cleanups. If it is zero or negative no periodic
//     cleanup happens and only the final cleanup on stop is performed.
//   - stopChan: signals shutdown; one final cleanup runs before returning.
func CleanupIdleConnections(client *http.Client, interval time.Duration, stopChan <-chan struct{}) {
	if client == nil {
		return
	}

	// The transport is looked up on every call because client.Transport may
	// be replaced while this loop is running.
	closeIdle := func() {
		if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
			transport.CloseIdleConnections()
		}
	}

	// time.NewTicker panics for non-positive durations. In that case tick
	// stays nil, and receiving from a nil channel blocks forever, so the
	// select below simply waits for the stop signal.
	var tick <-chan time.Time
	if interval > 0 {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		tick = ticker.C
	}

	for {
		select {
		case <-tick:
			closeIdle()
		case <-stopChan:
			// Final cleanup
			closeIdle()
			return
		}
	}
}
|
||||
|
||||
// OptimizedMiddlewareConfig groups the switches for running the middleware in
// a memory-optimized mode. The code that interprets these flags is not part
// of this file section.
type OptimizedMiddlewareConfig struct {
	// DelayBackgroundTasks postpones background tasks until the first request.
	DelayBackgroundTasks bool

	// ReducedCleanupIntervals selects longer intervals for cleanup tasks.
	ReducedCleanupIntervals bool

	// AggressiveConnectionCleanup closes idle connections more eagerly.
	AggressiveConnectionCleanup bool

	// MinimalCacheSize selects smaller default cache sizes.
	MinimalCacheSize bool
}
|
||||
|
||||
// DefaultOptimizedConfig returns a configuration optimized for low memory usage
|
||||
func DefaultOptimizedConfig() *OptimizedMiddlewareConfig {
|
||||
return &OptimizedMiddlewareConfig{
|
||||
DelayBackgroundTasks: true,
|
||||
ReducedCleanupIntervals: true,
|
||||
AggressiveConnectionCleanup: true,
|
||||
MinimalCacheSize: true,
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -48,7 +48,7 @@ func TestMemoryLeakFixes(t *testing.T) {
|
||||
|
||||
t.Run("Global cache manager cleanup", func(t *testing.T) {
|
||||
// Get the global cache manager
|
||||
cm := GetGlobalCacheManager()
|
||||
cm := GetGlobalCacheManager(nil)
|
||||
if cm == nil {
|
||||
t.Fatal("Failed to get global cache manager")
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func TestMemoryLeakFixes(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify it can be re-initialized
|
||||
cm2 := GetGlobalCacheManager()
|
||||
cm2 := GetGlobalCacheManager(nil)
|
||||
if cm2 == nil {
|
||||
t.Fatal("Failed to re-initialize global cache manager")
|
||||
}
|
||||
|
||||
+9
-5
@@ -19,23 +19,27 @@ type MetadataCache struct {
|
||||
logger *Logger
|
||||
autoCleanupInterval time.Duration
|
||||
mutex sync.RWMutex
|
||||
wg *sync.WaitGroup
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewMetadataCache creates a new MetadataCache instance.
|
||||
// It initializes the cache structure and starts the background cleanup task.
|
||||
func NewMetadataCache() *MetadataCache {
|
||||
return NewMetadataCacheWithLogger(nil)
|
||||
func NewMetadataCache(wg *sync.WaitGroup) *MetadataCache {
|
||||
return NewMetadataCacheWithLogger(wg, nil)
|
||||
}
|
||||
|
||||
// NewMetadataCacheWithLogger creates a new MetadataCache with a specified logger.
|
||||
func NewMetadataCacheWithLogger(logger *Logger) *MetadataCache {
|
||||
func NewMetadataCacheWithLogger(wg *sync.WaitGroup, logger *Logger) *MetadataCache {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
c := &MetadataCache{
|
||||
autoCleanupInterval: 5 * time.Minute,
|
||||
autoCleanupInterval: 30 * time.Minute, // Increased from 5 minutes since metadata changes rarely
|
||||
logger: logger,
|
||||
wg: wg,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
c.startAutoCleanup()
|
||||
return c
|
||||
@@ -200,7 +204,7 @@ func (c *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client,
|
||||
// startAutoCleanup starts the background task that periodically calls Cleanup
|
||||
// to remove expired metadata from the cache.
|
||||
func (c *MetadataCache) startAutoCleanup() {
|
||||
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger)
|
||||
c.cleanupTask = NewBackgroundTask("metadata-cache-cleanup", c.autoCleanupInterval, c.Cleanup, c.logger, c.wg)
|
||||
c.cleanupTask.Start()
|
||||
}
|
||||
|
||||
|
||||
+852
@@ -0,0 +1,852 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryProfiler defines the interface for memory profiling operations
|
||||
type MemoryProfiler interface {
|
||||
// TakeSnapshot captures current memory statistics
|
||||
TakeSnapshot() (*MemorySnapshot, error)
|
||||
|
||||
// StartProfiling begins memory profiling with specified configuration
|
||||
StartProfiling(config ProfilingConfig) error
|
||||
|
||||
// StopProfiling ends memory profiling and returns final snapshot
|
||||
StopProfiling() (*MemorySnapshot, error)
|
||||
|
||||
// GetCurrentStats returns current runtime memory statistics
|
||||
GetCurrentStats() *runtime.MemStats
|
||||
|
||||
// AnalyzeLeaks performs leak detection analysis
|
||||
AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis
|
||||
}
|
||||
|
||||
// MemorySnapshot is a point-in-time capture of memory statistics, as produced
// by ProfilingManager.TakeSnapshot.
type MemorySnapshot struct {
	// Timestamp is when the snapshot was taken.
	Timestamp time.Time

	// RuntimeStats holds the values filled in by runtime.ReadMemStats.
	RuntimeStats runtime.MemStats

	// HeapProfile is the raw pprof heap profile; empty when heap profiling
	// is disabled or capturing it failed.
	HeapProfile []byte

	// GoroutineProfile is the raw pprof goroutine profile; empty when
	// goroutine profiling is disabled or capturing it failed.
	GoroutineProfile []byte

	// CustomMetrics holds per-profiler statistics keyed by profiler name.
	CustomMetrics map[string]interface{}
}
|
||||
|
||||
// LeakAnalysis is the outcome of comparing two memory snapshots, as produced
// by ProfilingManager.AnalyzeLeaks.
type LeakAnalysis struct {
	// HasLeak is true when any leak check exceeded its threshold.
	HasLeak bool

	// LeakDescription is a one-line human-readable summary of the result.
	LeakDescription string

	// MemoryIncrease is the growth of allocated heap memory in bytes.
	MemoryIncrease uint64

	// GoroutineIncrease is the change in the number of goroutines.
	GoroutineIncrease int

	// SuspectedLeaks lists one message per check that flagged a leak.
	SuspectedLeaks []string

	// Recommendations lists follow-up hints matching SuspectedLeaks.
	Recommendations []string
}
|
||||
|
||||
// ProfilingManager coordinates memory profiling operations
|
||||
type ProfilingManager struct {
|
||||
mu sync.RWMutex
|
||||
isProfiling bool
|
||||
startTime time.Time
|
||||
baselineSnapshot *MemorySnapshot
|
||||
config ProfilingConfig
|
||||
logger *Logger
|
||||
profilers map[string]MemoryProfiler
|
||||
}
|
||||
|
||||
// ProfilingConfig holds the settings for a profiling session. Only
// EnableHeapProfiling, EnableGoroutineProfiling and LeakThresholdMB are read
// by the ProfilingManager methods in this file section; the consumers of the
// remaining fields are not visible here.
type ProfilingConfig struct {
	// EnableHeapProfiling makes TakeSnapshot capture a heap profile.
	EnableHeapProfiling bool

	// EnableGoroutineProfiling makes TakeSnapshot capture a goroutine profile.
	EnableGoroutineProfiling bool

	// SnapshotInterval is the time between snapshots.
	SnapshotInterval time.Duration

	// LeakThresholdMB is the memory growth, in megabytes, above which
	// AnalyzeLeaks reports a leak.
	LeakThresholdMB uint64

	// MaxSnapshots is the maximum number of snapshots to keep.
	MaxSnapshots int

	// EnableContinuousMonitoring turns on continuous monitoring.
	EnableContinuousMonitoring bool

	// MonitoringInterval is the time between continuous monitoring runs.
	MonitoringInterval time.Duration
}
|
||||
|
||||
// LeakDetectionConfig holds the switches and thresholds for leak detection.
// The code that interprets these values is not part of this file section, so
// units beyond what the field names state should be confirmed at the use site.
type LeakDetectionConfig struct {
	// EnableLeakDetection turns leak detection on.
	EnableLeakDetection bool

	// LeakThresholdMB is the memory threshold in megabytes.
	LeakThresholdMB uint64

	// GoroutineLeakThreshold is the threshold for goroutine growth.
	GoroutineLeakThreshold int

	// SessionPoolThreshold is the threshold for the session pool.
	SessionPoolThreshold int

	// CacheMemoryThreshold is the threshold for cache memory.
	CacheMemoryThreshold uint64

	// HTTPClientThreshold is the threshold for HTTP clients.
	HTTPClientThreshold int

	// TokenCompressionThreshold is the threshold for token compression.
	TokenCompressionThreshold uint64
}
|
||||
|
||||
// NewProfilingManager creates a new profiling manager instance
|
||||
func NewProfilingManager(logger *Logger) *ProfilingManager {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
return &ProfilingManager{
|
||||
profilers: make(map[string]MemoryProfiler),
|
||||
config: ProfilingConfig{
|
||||
EnableHeapProfiling: true,
|
||||
EnableGoroutineProfiling: true,
|
||||
SnapshotInterval: 30 * time.Second,
|
||||
LeakThresholdMB: 50, // 50MB
|
||||
MaxSnapshots: 100,
|
||||
EnableContinuousMonitoring: true,
|
||||
MonitoringInterval: 60 * time.Second,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures current memory statistics
|
||||
func (pm *ProfilingManager) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
var buf bytes.Buffer
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Capture runtime memory statistics
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Capture heap profile if enabled
|
||||
if pm.config.EnableHeapProfiling {
|
||||
if err := pprof.WriteHeapProfile(&buf); err != nil {
|
||||
pm.logger.Errorf("Failed to capture heap profile: %v", err)
|
||||
} else {
|
||||
snapshot.HeapProfile = make([]byte, buf.Len())
|
||||
copy(snapshot.HeapProfile, buf.Bytes())
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture goroutine profile if enabled
|
||||
if pm.config.EnableGoroutineProfiling {
|
||||
if err := pprof.Lookup("goroutine").WriteTo(&buf, 0); err != nil {
|
||||
pm.logger.Errorf("Failed to capture goroutine profile: %v", err)
|
||||
} else {
|
||||
snapshot.GoroutineProfile = make([]byte, buf.Len())
|
||||
copy(snapshot.GoroutineProfile, buf.Bytes())
|
||||
buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// Capture custom metrics from registered profilers
|
||||
pm.mu.RLock()
|
||||
for name, profiler := range pm.profilers {
|
||||
if customStats := profiler.GetCurrentStats(); customStats != nil {
|
||||
snapshot.CustomMetrics[name] = customStats
|
||||
}
|
||||
}
|
||||
pm.mu.RUnlock()
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins memory profiling with specified configuration
|
||||
func (pm *ProfilingManager) StartProfiling(config ProfilingConfig) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if pm.isProfiling {
|
||||
return fmt.Errorf("profiling already in progress")
|
||||
}
|
||||
|
||||
pm.config = config
|
||||
pm.isProfiling = true
|
||||
pm.startTime = time.Now()
|
||||
|
||||
// Take baseline snapshot
|
||||
baseline, err := pm.TakeSnapshot()
|
||||
if err != nil {
|
||||
pm.isProfiling = false
|
||||
return fmt.Errorf("failed to take baseline snapshot: %w", err)
|
||||
}
|
||||
pm.baselineSnapshot = baseline
|
||||
|
||||
pm.logger.Infof("Memory profiling started at %v", pm.startTime)
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends memory profiling and returns final snapshot
|
||||
func (pm *ProfilingManager) StopProfiling() (*MemorySnapshot, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if !pm.isProfiling {
|
||||
return nil, fmt.Errorf("profiling not in progress")
|
||||
}
|
||||
|
||||
// Take final snapshot
|
||||
finalSnapshot, err := pm.TakeSnapshot()
|
||||
if err != nil {
|
||||
pm.logger.Errorf("Failed to take final snapshot: %v", err)
|
||||
}
|
||||
|
||||
pm.isProfiling = false
|
||||
duration := time.Since(pm.startTime)
|
||||
|
||||
pm.logger.Infof("Memory profiling stopped after %v", duration)
|
||||
return finalSnapshot, err
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current runtime memory statistics
|
||||
func (pm *ProfilingManager) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks performs leak detection analysis
|
||||
func (pm *ProfilingManager) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.HasLeak = false
|
||||
analysis.LeakDescription = "Insufficient data for leak analysis"
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Calculate memory increase
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
analysis.MemoryIncrease = memoryIncrease
|
||||
|
||||
// Calculate goroutine increase
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
baselineGoroutines := runtime.NumGoroutine() // Note: This is not accurate for baseline, but we don't have historical data
|
||||
goroutineIncrease := currentGoroutines - baselineGoroutines
|
||||
analysis.GoroutineIncrease = goroutineIncrease
|
||||
|
||||
// Check for memory leaks
|
||||
memoryThreshold := pm.config.LeakThresholdMB * 1024 * 1024 // Convert MB to bytes
|
||||
if memoryIncrease > memoryThreshold {
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
fmt.Sprintf("Memory usage increased by %.2f MB", float64(memoryIncrease)/(1024*1024)))
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Consider checking for unreleased memory pools or growing caches")
|
||||
}
|
||||
|
||||
// Check for goroutine leaks
|
||||
if goroutineIncrease > 10 { // Arbitrary threshold
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
fmt.Sprintf("Goroutine count increased by %d", goroutineIncrease))
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for goroutines that are not being properly cleaned up")
|
||||
}
|
||||
|
||||
if analysis.HasLeak {
|
||||
analysis.LeakDescription = fmt.Sprintf("Potential memory leak detected: %s",
|
||||
fmt.Sprintf("%.2f MB increase, %d goroutines", float64(memoryIncrease)/(1024*1024), goroutineIncrease))
|
||||
} else {
|
||||
analysis.LeakDescription = "No significant memory leaks detected"
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// RegisterProfiler registers a component-specific profiler
|
||||
func (pm *ProfilingManager) RegisterProfiler(name string, profiler MemoryProfiler) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.profilers[name] = profiler
|
||||
pm.logger.Debugf("Registered profiler: %s", name)
|
||||
}
|
||||
|
||||
// UnregisterProfiler removes a component-specific profiler
|
||||
func (pm *ProfilingManager) UnregisterProfiler(name string) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
delete(pm.profilers, name)
|
||||
pm.logger.Debugf("Unregistered profiler: %s", name)
|
||||
}
|
||||
|
||||
// GetRegisteredProfilers returns list of registered profiler names
|
||||
func (pm *ProfilingManager) GetRegisteredProfilers() []string {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(pm.profilers))
|
||||
for name := range pm.profilers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MemoryTestOrchestrator coordinates memory leak testing across components
|
||||
type MemoryTestOrchestrator struct {
|
||||
mu sync.RWMutex
|
||||
profilers map[string]MemoryProfiler
|
||||
config LeakDetectionConfig
|
||||
logger *Logger
|
||||
isRunning bool
|
||||
stopChan chan struct{}
|
||||
testResults map[string]*LeakAnalysis
|
||||
}
|
||||
|
||||
// NewMemoryTestOrchestrator creates a new test orchestrator
|
||||
func NewMemoryTestOrchestrator(config LeakDetectionConfig, logger *Logger) *MemoryTestOrchestrator {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
|
||||
return &MemoryTestOrchestrator{
|
||||
profilers: make(map[string]MemoryProfiler),
|
||||
config: config,
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
testResults: make(map[string]*LeakAnalysis),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterComponent registers a component for memory leak testing
|
||||
func (mto *MemoryTestOrchestrator) RegisterComponent(name string, profiler MemoryProfiler) {
|
||||
mto.mu.Lock()
|
||||
defer mto.mu.Unlock()
|
||||
mto.profilers[name] = profiler
|
||||
mto.logger.Debugf("Registered component for leak testing: %s", name)
|
||||
}
|
||||
|
||||
// UnregisterComponent removes a component from leak testing
|
||||
func (mto *MemoryTestOrchestrator) UnregisterComponent(name string) {
|
||||
mto.mu.Lock()
|
||||
defer mto.mu.Unlock()
|
||||
delete(mto.profilers, name)
|
||||
delete(mto.testResults, name)
|
||||
mto.logger.Debugf("Unregistered component from leak testing: %s", name)
|
||||
}
|
||||
|
||||
// StartLeakDetection begins continuous leak detection monitoring
|
||||
func (mto *MemoryTestOrchestrator) StartLeakDetection() error {
|
||||
mto.mu.Lock()
|
||||
defer mto.mu.Unlock()
|
||||
|
||||
if mto.isRunning {
|
||||
return fmt.Errorf("leak detection already running")
|
||||
}
|
||||
|
||||
if !mto.config.EnableLeakDetection {
|
||||
return fmt.Errorf("leak detection is disabled in configuration")
|
||||
}
|
||||
|
||||
mto.isRunning = true
|
||||
go mto.runLeakDetection()
|
||||
|
||||
mto.logger.Infof("Memory leak detection started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopLeakDetection stops continuous leak detection monitoring
|
||||
func (mto *MemoryTestOrchestrator) StopLeakDetection() error {
|
||||
mto.mu.Lock()
|
||||
defer mto.mu.Unlock()
|
||||
|
||||
if !mto.isRunning {
|
||||
return fmt.Errorf("leak detection not running")
|
||||
}
|
||||
|
||||
mto.isRunning = false
|
||||
close(mto.stopChan)
|
||||
mto.stopChan = make(chan struct{}) // Reset for potential restart
|
||||
|
||||
mto.logger.Infof("Memory leak detection stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// runLeakDetection performs continuous leak detection monitoring
|
||||
func (mto *MemoryTestOrchestrator) runLeakDetection() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
baselineSnapshots := make(map[string]*MemorySnapshot)
|
||||
|
||||
// Take initial baseline snapshots
|
||||
mto.mu.RLock()
|
||||
for name, profiler := range mto.profilers {
|
||||
if snapshot, err := profiler.TakeSnapshot(); err == nil {
|
||||
baselineSnapshots[name] = snapshot
|
||||
}
|
||||
}
|
||||
mto.mu.RUnlock()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
mto.performLeakCheck(baselineSnapshots)
|
||||
case <-mto.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performLeakCheck performs leak detection for all registered components
|
||||
func (mto *MemoryTestOrchestrator) performLeakCheck(baselineSnapshots map[string]*MemorySnapshot) {
|
||||
mto.mu.RLock()
|
||||
defer mto.mu.RUnlock()
|
||||
|
||||
for name, profiler := range mto.profilers {
|
||||
baseline, exists := baselineSnapshots[name]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
current, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
mto.logger.Errorf("Failed to take snapshot for component %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
analysis := profiler.AnalyzeLeaks(baseline, current)
|
||||
if analysis.HasLeak {
|
||||
mto.logger.Errorf("Memory leak detected in component %s: %s", name, analysis.LeakDescription)
|
||||
for _, rec := range analysis.Recommendations {
|
||||
mto.logger.Errorf("Recommendation for %s: %s", name, rec)
|
||||
}
|
||||
}
|
||||
|
||||
mto.testResults[name] = analysis
|
||||
}
|
||||
}
|
||||
|
||||
// GetLeakAnalysis returns leak analysis for a specific component
|
||||
func (mto *MemoryTestOrchestrator) GetLeakAnalysis(componentName string) (*LeakAnalysis, bool) {
|
||||
mto.mu.RLock()
|
||||
defer mto.mu.RUnlock()
|
||||
analysis, exists := mto.testResults[componentName]
|
||||
return analysis, exists
|
||||
}
|
||||
|
||||
// GetAllLeakAnalyses returns leak analyses for all components
|
||||
func (mto *MemoryTestOrchestrator) GetAllLeakAnalyses() map[string]*LeakAnalysis {
|
||||
mto.mu.RLock()
|
||||
defer mto.mu.RUnlock()
|
||||
|
||||
results := make(map[string]*LeakAnalysis)
|
||||
for name, analysis := range mto.testResults {
|
||||
results[name] = analysis
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// Component-specific profiler implementations
|
||||
|
||||
// SessionPoolProfiler monitors session pool memory usage
|
||||
type SessionPoolProfiler struct {
|
||||
sessionManager *SessionManager
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewSessionPoolProfiler creates a new session pool profiler
|
||||
func NewSessionPoolProfiler(sm *SessionManager, logger *Logger) *SessionPoolProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &SessionPoolProfiler{
|
||||
sessionManager: sm,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures session pool memory statistics
|
||||
func (spp *SessionPoolProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Capture runtime stats
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add session pool specific metrics
|
||||
snapshot.CustomMetrics["session_pool_metrics"] = spp.sessionManager.GetSessionMetrics()
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for session pools)
|
||||
func (spp *SessionPoolProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling (no-op for session pools)
|
||||
func (spp *SessionPoolProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return spp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (spp *SessionPoolProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes session pool for leaks
|
||||
func (spp *SessionPoolProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient session pool data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Check for session pool specific leaks
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 10*1024*1024 { // 10MB threshold
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"Session pool memory usage increased significantly")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for sessions not being returned to pool properly")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// CacheMemoryProfiler monitors cache memory usage
|
||||
type CacheMemoryProfiler struct {
|
||||
cache *Cache
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewCacheMemoryProfiler creates a new cache memory profiler
|
||||
func NewCacheMemoryProfiler(cache *Cache, logger *Logger) *CacheMemoryProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &CacheMemoryProfiler{
|
||||
cache: cache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures cache memory statistics
|
||||
func (cmp *CacheMemoryProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add cache-specific metrics (would need to be added to Cache struct)
|
||||
snapshot.CustomMetrics["cache_size"] = "unknown" // Placeholder
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for cache)
|
||||
func (cmp *CacheMemoryProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling
|
||||
func (cmp *CacheMemoryProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return cmp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (cmp *CacheMemoryProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes cache for memory leaks
|
||||
func (cmp *CacheMemoryProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient cache data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 20*1024*1024 { // 20MB threshold for cache
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"Cache memory usage increased significantly")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check cache size limits and cleanup intervals")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// HTTPClientProfiler monitors HTTP client connection pools
|
||||
type HTTPClientProfiler struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewHTTPClientProfiler creates a new HTTP client profiler
|
||||
func NewHTTPClientProfiler(client *http.Client, logger *Logger) *HTTPClientProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &HTTPClientProfiler{
|
||||
httpClient: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures HTTP client memory statistics
|
||||
func (hcp *HTTPClientProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add HTTP client specific metrics
|
||||
if transport, ok := hcp.httpClient.Transport.(*http.Transport); ok {
|
||||
snapshot.CustomMetrics["idle_connections"] = transport.IdleConnTimeout.String()
|
||||
snapshot.CustomMetrics["max_idle_conns"] = transport.MaxIdleConns
|
||||
}
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for HTTP client)
|
||||
func (hcp *HTTPClientProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling
|
||||
func (hcp *HTTPClientProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return hcp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (hcp *HTTPClientProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes HTTP client for connection leaks
|
||||
func (hcp *HTTPClientProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient HTTP client data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 5*1024*1024 { // 5MB threshold for HTTP client
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"HTTP client memory usage increased significantly")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for HTTP response bodies not being drained properly")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// TokenCompressionProfiler monitors token compression memory usage
|
||||
type TokenCompressionProfiler struct {
|
||||
compressionPool *TokenCompressionPool
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewTokenCompressionProfiler creates a new token compression profiler
|
||||
func NewTokenCompressionProfiler(pool *TokenCompressionPool, logger *Logger) *TokenCompressionProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &TokenCompressionProfiler{
|
||||
compressionPool: pool,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures token compression memory statistics
|
||||
func (tcp *TokenCompressionProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add compression pool specific metrics
|
||||
snapshot.CustomMetrics["compression_pool_active"] = true
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for compression)
|
||||
func (tcp *TokenCompressionProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling
|
||||
func (tcp *TokenCompressionProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return tcp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (tcp *TokenCompressionProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes token compression for memory leaks
|
||||
func (tcp *TokenCompressionProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient compression data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 2*1024*1024 { // 2MB threshold for compression
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"Token compression memory usage increased significantly")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for compression buffers not being returned to pool")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// MemoryPoolProfiler monitors memory pool usage and detects leaks
|
||||
type MemoryPoolProfiler struct {
|
||||
memoryPoolManager *MemoryPoolManager
|
||||
tokenCompressionPool *TokenCompressionPool
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewMemoryPoolProfiler creates a new memory pool profiler
|
||||
func NewMemoryPoolProfiler(memoryPoolManager *MemoryPoolManager, tokenCompressionPool *TokenCompressionPool, logger *Logger) *MemoryPoolProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &MemoryPoolProfiler{
|
||||
memoryPoolManager: memoryPoolManager,
|
||||
tokenCompressionPool: tokenCompressionPool,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures memory pool statistics
|
||||
func (mpp *MemoryPoolProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Capture runtime stats
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add memory pool metrics
|
||||
if mpp.memoryPoolManager != nil {
|
||||
snapshot.CustomMetrics["memory_pool_active"] = true
|
||||
// Note: sync.Pool doesn't expose internal statistics, so we track usage patterns
|
||||
}
|
||||
|
||||
if mpp.tokenCompressionPool != nil {
|
||||
snapshot.CustomMetrics["token_compression_pool_active"] = true
|
||||
}
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for memory pools)
|
||||
func (mpp *MemoryPoolProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling
|
||||
func (mpp *MemoryPoolProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return mpp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (mpp *MemoryPoolProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes memory pools for leaks
|
||||
func (mpp *MemoryPoolProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient memory pool data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Check for memory leaks in pool operations
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 5*1024*1024 { // 5MB threshold for pool operations
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"Memory pool operations caused significant memory increase")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for objects not being returned to memory pools properly")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Global profiling manager instance
|
||||
var globalProfilingManager *ProfilingManager
|
||||
var profilingManagerOnce sync.Once
|
||||
|
||||
// GetGlobalProfilingManager returns the singleton profiling manager
|
||||
func GetGlobalProfilingManager() *ProfilingManager {
|
||||
profilingManagerOnce.Do(func() {
|
||||
globalProfilingManager = NewProfilingManager(nil)
|
||||
})
|
||||
return globalProfilingManager
|
||||
}
|
||||
|
||||
// Global test orchestrator instance
|
||||
var globalTestOrchestrator *MemoryTestOrchestrator
|
||||
var testOrchestratorOnce sync.Once
|
||||
|
||||
// GetGlobalTestOrchestrator returns the singleton test orchestrator
|
||||
func GetGlobalTestOrchestrator() *MemoryTestOrchestrator {
|
||||
testOrchestratorOnce.Do(func() {
|
||||
config := LeakDetectionConfig{
|
||||
EnableLeakDetection: true,
|
||||
LeakThresholdMB: 50,
|
||||
GoroutineLeakThreshold: 10,
|
||||
SessionPoolThreshold: 100,
|
||||
CacheMemoryThreshold: 20 * 1024 * 1024, // 20MB
|
||||
HTTPClientThreshold: 50,
|
||||
TokenCompressionThreshold: 2 * 1024 * 1024, // 2MB
|
||||
}
|
||||
globalTestOrchestrator = NewMemoryTestOrchestrator(config, nil)
|
||||
})
|
||||
return globalTestOrchestrator
|
||||
}
|
||||
@@ -0,0 +1,788 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestProfilingManager(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
pm := NewProfilingManager(logger)
|
||||
|
||||
// Test taking a snapshot
|
||||
snapshot, err := pm.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take snapshot: %v", err)
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
t.Fatal("Snapshot is nil")
|
||||
}
|
||||
|
||||
if snapshot.RuntimeStats.Alloc == 0 {
|
||||
t.Error("Runtime stats Alloc should not be zero")
|
||||
}
|
||||
|
||||
if snapshot.Timestamp.IsZero() {
|
||||
t.Error("Snapshot timestamp should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryTestOrchestrator(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
config := LeakDetectionConfig{
|
||||
EnableLeakDetection: true,
|
||||
LeakThresholdMB: 10,
|
||||
}
|
||||
|
||||
mto := NewMemoryTestOrchestrator(config, logger)
|
||||
|
||||
// Test registering a component
|
||||
sessionManager, err := NewSessionManager("test-key-32-chars-long-for-testing", false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
profiler := NewSessionPoolProfiler(sessionManager, logger)
|
||||
mto.RegisterComponent("session_pool", profiler)
|
||||
|
||||
// Test getting leak analysis (should return false initially since no checks have been performed)
|
||||
_, exists := mto.GetLeakAnalysis("session_pool")
|
||||
if exists {
|
||||
t.Error("Should not have leak analysis before any checks are performed")
|
||||
}
|
||||
|
||||
// Perform a manual leak check
|
||||
baseline, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take baseline snapshot: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // Small delay
|
||||
|
||||
// Manually trigger leak check with baseline
|
||||
baselineSnapshots := make(map[string]*MemorySnapshot)
|
||||
baselineSnapshots["session_pool"] = baseline
|
||||
mto.performLeakCheck(baselineSnapshots)
|
||||
|
||||
// Now test getting leak analysis
|
||||
analysis, exists := mto.GetLeakAnalysis("session_pool")
|
||||
if !exists {
|
||||
t.Error("Should have leak analysis after performing checks")
|
||||
}
|
||||
|
||||
if analysis == nil {
|
||||
t.Error("Leak analysis should not be nil after checks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComponentProfilers(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
|
||||
// Test Session Pool Profiler
|
||||
sessionManager, err := NewSessionManager("test-key-32-chars-long-for-testing", false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
spp := NewSessionPoolProfiler(sessionManager, logger)
|
||||
snapshot, err := spp.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take session pool snapshot: %v", err)
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
t.Fatal("Session pool snapshot is nil")
|
||||
}
|
||||
|
||||
// Test Cache Memory Profiler
|
||||
cache := NewCache()
|
||||
cmp := NewCacheMemoryProfiler(cache, logger)
|
||||
snapshot, err = cmp.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take cache snapshot: %v", err)
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
t.Fatal("Cache snapshot is nil")
|
||||
}
|
||||
|
||||
// Test HTTP Client Profiler
|
||||
httpClient := createDefaultHTTPClient()
|
||||
hcp := NewHTTPClientProfiler(httpClient, logger)
|
||||
snapshot, err = hcp.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take HTTP client snapshot: %v", err)
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
t.Fatal("HTTP client snapshot is nil")
|
||||
}
|
||||
|
||||
// Test Token Compression Profiler
|
||||
compressionPool := NewTokenCompressionPool()
|
||||
tcp := NewTokenCompressionProfiler(compressionPool, logger)
|
||||
snapshot, err = tcp.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take compression snapshot: %v", err)
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
t.Fatal("Compression snapshot is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeakAnalysis(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
pm := NewProfilingManager(logger)
|
||||
|
||||
// Create baseline snapshot
|
||||
baseline, err := pm.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create baseline: %v", err)
|
||||
}
|
||||
|
||||
// Wait a bit and create current snapshot
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
current, err := pm.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create current snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Test leak analysis
|
||||
analysis := pm.AnalyzeLeaks(baseline, current)
|
||||
if analysis == nil {
|
||||
t.Fatal("Leak analysis is nil")
|
||||
}
|
||||
|
||||
// Analysis should not have leaks for normal operation
|
||||
if analysis.HasLeak {
|
||||
t.Logf("Leak detected: %s", analysis.LeakDescription)
|
||||
// This is acceptable as the test environment may have varying memory usage
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalInstances(t *testing.T) {
|
||||
// Test global profiling manager
|
||||
gpm := GetGlobalProfilingManager()
|
||||
if gpm == nil {
|
||||
t.Fatal("Global profiling manager is nil")
|
||||
}
|
||||
|
||||
// Test global test orchestrator
|
||||
gto := GetGlobalTestOrchestrator()
|
||||
if gto == nil {
|
||||
t.Fatal("Global test orchestrator is nil")
|
||||
}
|
||||
|
||||
// Test that they're singletons
|
||||
gpm2 := GetGlobalProfilingManager()
|
||||
if gpm != gpm2 {
|
||||
t.Error("Global profiling manager should be singleton")
|
||||
}
|
||||
|
||||
gto2 := GetGlobalTestOrchestrator()
|
||||
if gto != gto2 {
|
||||
t.Error("Global test orchestrator should be singleton")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProfilingConfig(t *testing.T) {
|
||||
config := ProfilingConfig{
|
||||
EnableHeapProfiling: true,
|
||||
EnableGoroutineProfiling: true,
|
||||
SnapshotInterval: 30 * time.Second,
|
||||
LeakThresholdMB: 50,
|
||||
MaxSnapshots: 100,
|
||||
EnableContinuousMonitoring: true,
|
||||
MonitoringInterval: 60 * time.Second,
|
||||
}
|
||||
|
||||
if !config.EnableHeapProfiling {
|
||||
t.Error("Heap profiling should be enabled")
|
||||
}
|
||||
|
||||
if !config.EnableGoroutineProfiling {
|
||||
t.Error("Goroutine profiling should be enabled")
|
||||
}
|
||||
|
||||
if config.LeakThresholdMB != 50 {
|
||||
t.Errorf("Expected leak threshold 50, got %d", config.LeakThresholdMB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeakDetectionConfig(t *testing.T) {
|
||||
config := LeakDetectionConfig{
|
||||
EnableLeakDetection: true,
|
||||
LeakThresholdMB: 50,
|
||||
GoroutineLeakThreshold: 10,
|
||||
SessionPoolThreshold: 100,
|
||||
CacheMemoryThreshold: 20 * 1024 * 1024,
|
||||
HTTPClientThreshold: 50,
|
||||
TokenCompressionThreshold: 2 * 1024 * 1024,
|
||||
}
|
||||
|
||||
if !config.EnableLeakDetection {
|
||||
t.Error("Leak detection should be enabled")
|
||||
}
|
||||
|
||||
if config.LeakThresholdMB != 50 {
|
||||
t.Errorf("Expected leak threshold 50, got %d", config.LeakThresholdMB)
|
||||
}
|
||||
|
||||
if config.CacheMemoryThreshold != 20*1024*1024 {
|
||||
t.Errorf("Expected cache threshold 20MB, got %d", config.CacheMemoryThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderMetadataProfiler monitors provider metadata fetching and caching operations
|
||||
type ProviderMetadataProfiler struct {
|
||||
metadataCache *MetadataCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
providerURL string
|
||||
}
|
||||
|
||||
// NewProviderMetadataProfiler creates a new provider metadata profiler
|
||||
func NewProviderMetadataProfiler(metadataCache *MetadataCache, httpClient *http.Client, providerURL string, logger *Logger) *ProviderMetadataProfiler {
|
||||
if logger == nil {
|
||||
logger = newNoOpLogger()
|
||||
}
|
||||
return &ProviderMetadataProfiler{
|
||||
metadataCache: metadataCache,
|
||||
httpClient: httpClient,
|
||||
providerURL: providerURL,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TakeSnapshot captures current memory statistics for metadata operations
|
||||
func (pmp *ProviderMetadataProfiler) TakeSnapshot() (*MemorySnapshot, error) {
|
||||
snapshot := &MemorySnapshot{
|
||||
Timestamp: time.Now(),
|
||||
CustomMetrics: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Capture runtime memory statistics
|
||||
runtime.ReadMemStats(&snapshot.RuntimeStats)
|
||||
|
||||
// Add metadata-specific metrics
|
||||
snapshot.CustomMetrics["metadata_cache_size"] = 1 // Placeholder for cache size
|
||||
snapshot.CustomMetrics["metadata_fetch_count"] = 0 // Placeholder for fetch count
|
||||
snapshot.CustomMetrics["background_goroutines"] = runtime.NumGoroutine()
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// StartProfiling begins profiling (no-op for metadata profiler)
|
||||
func (pmp *ProviderMetadataProfiler) StartProfiling(config ProfilingConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopProfiling ends profiling
|
||||
func (pmp *ProviderMetadataProfiler) StopProfiling() (*MemorySnapshot, error) {
|
||||
return pmp.TakeSnapshot()
|
||||
}
|
||||
|
||||
// GetCurrentStats returns current memory statistics
|
||||
func (pmp *ProviderMetadataProfiler) GetCurrentStats() *runtime.MemStats {
|
||||
stats := &runtime.MemStats{}
|
||||
runtime.ReadMemStats(stats)
|
||||
return stats
|
||||
}
|
||||
|
||||
// AnalyzeLeaks analyzes metadata operations for memory leaks
|
||||
func (pmp *ProviderMetadataProfiler) AnalyzeLeaks(baseline, current *MemorySnapshot) *LeakAnalysis {
|
||||
analysis := &LeakAnalysis{
|
||||
SuspectedLeaks: make([]string, 0),
|
||||
Recommendations: make([]string, 0),
|
||||
}
|
||||
|
||||
if baseline == nil || current == nil {
|
||||
analysis.LeakDescription = "Insufficient metadata data"
|
||||
return analysis
|
||||
}
|
||||
|
||||
// Check for memory leaks
|
||||
memoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if memoryIncrease > 5*1024*1024 { // 5MB threshold for metadata operations
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
"Metadata operations memory usage increased significantly")
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for metadata cache not being cleaned up properly")
|
||||
}
|
||||
|
||||
// Check for goroutine leaks
|
||||
goroutineIncrease := current.CustomMetrics["background_goroutines"].(int) - baseline.CustomMetrics["background_goroutines"].(int)
|
||||
if goroutineIncrease > 2 { // Allow some variance
|
||||
analysis.HasLeak = true
|
||||
analysis.SuspectedLeaks = append(analysis.SuspectedLeaks,
|
||||
fmt.Sprintf("Goroutine count increased by %d during metadata operations", goroutineIncrease))
|
||||
analysis.Recommendations = append(analysis.Recommendations,
|
||||
"Check for background goroutines not being cleaned up")
|
||||
}
|
||||
|
||||
return analysis
|
||||
}
|
||||
|
||||
// TestProviderMetadataMemoryLeakDetection tests for memory leaks in provider metadata operations
|
||||
func TestProviderMetadataMemoryLeakDetection(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
|
||||
strictMode := os.Getenv("STRICT_MEMORY_TEST") == "true"
|
||||
if strictMode {
|
||||
t.Log("Running in strict memory test mode - will fail on detected leaks")
|
||||
} else {
|
||||
t.Log("Running in lenient memory test mode - will log warnings instead of failing")
|
||||
}
|
||||
|
||||
config := LeakDetectionConfig{
|
||||
EnableLeakDetection: true,
|
||||
LeakThresholdMB: 10,
|
||||
}
|
||||
|
||||
mto := NewMemoryTestOrchestrator(config, logger)
|
||||
|
||||
// Create mock HTTP server for metadata endpoint with failure simulation
|
||||
requestCount := 0
|
||||
serverFailures := 0
|
||||
mockServer := &http.Server{
|
||||
Addr: "localhost:0", // Let system assign port
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
// Simulate occasional failures to test cache extension
|
||||
if requestCount%4 == 0 { // Fail every 4th request
|
||||
serverFailures++
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://mock-provider.com",
|
||||
AuthURL: "https://mock-provider.com/auth",
|
||||
TokenURL: "https://mock-provider.com/token",
|
||||
JWKSURL: "https://mock-provider.com/jwks",
|
||||
RevokeURL: "https://mock-provider.com/revoke",
|
||||
EndSessionURL: "https://mock-provider.com/logout",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "max-age=3600") // 1 hour cache hint
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
} else {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
// Start mock server
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
go mockServer.Serve(listener)
|
||||
defer mockServer.Close()
|
||||
|
||||
providerURL := fmt.Sprintf("http://%s", listener.Addr().String())
|
||||
httpClient := createDefaultHTTPClient()
|
||||
|
||||
// Create metadata cache with WaitGroup for proper goroutine synchronization
|
||||
var wg sync.WaitGroup
|
||||
metadataCache := NewMetadataCacheWithLogger(&wg, logger)
|
||||
|
||||
// Create profiler
|
||||
profiler := NewProviderMetadataProfiler(metadataCache, httpClient, providerURL, logger)
|
||||
mto.RegisterComponent("provider_metadata", profiler)
|
||||
|
||||
// Take initial baseline
|
||||
baseline, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take baseline snapshot: %v", err)
|
||||
}
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Phase 1: Simulate periodic metadata fetching with some failures
|
||||
t.Log("Phase 1: Testing periodic fetching with occasional failures...")
|
||||
for i := 0; i < 20; i++ {
|
||||
_, err := metadataCache.GetMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
t.Logf("Metadata fetch %d failed (expected for cache extension testing): %v", i+1, err)
|
||||
} else {
|
||||
t.Logf("Metadata fetch %d succeeded", i+1)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for background cleanup (normally every 5 minutes)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Take intermediate snapshot
|
||||
intermediate, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take intermediate snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Phase 2: Continue with more fetches to test sustained operation
|
||||
t.Log("Phase 2: Testing sustained operation with 1000 iterations...")
|
||||
for i := 20; i < 1020; i++ {
|
||||
_, err := metadataCache.GetMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
t.Logf("Metadata fetch %d failed: %v", i+1, err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // Reduced sleep for faster execution
|
||||
}
|
||||
|
||||
// Take final snapshot
|
||||
current, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take current snapshot: %v", err)
|
||||
}
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Analyze for leaks
|
||||
analysis := profiler.AnalyzeLeaks(baseline, current)
|
||||
|
||||
// Assertions for memory leaks
|
||||
if analysis.HasLeak {
|
||||
if strictMode {
|
||||
t.Errorf("Memory leak detected in provider metadata operations: %s", analysis.LeakDescription)
|
||||
for _, leak := range analysis.SuspectedLeaks {
|
||||
t.Errorf("Suspected leak: %s", leak)
|
||||
}
|
||||
} else {
|
||||
t.Logf("Memory leak warning in provider metadata operations: %s", analysis.LeakDescription)
|
||||
for _, leak := range analysis.SuspectedLeaks {
|
||||
t.Logf("Suspected leak: %s", leak)
|
||||
}
|
||||
}
|
||||
for _, rec := range analysis.Recommendations {
|
||||
t.Logf("Recommendation: %s", rec)
|
||||
}
|
||||
}
|
||||
|
||||
// Check total memory growth
|
||||
totalMemoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if totalMemoryIncrease > 20*1024*1024 { // 20MB threshold for entire test
|
||||
if strictMode {
|
||||
t.Errorf("Total memory usage increased by %.2f MB during metadata operations", float64(totalMemoryIncrease)/(1024*1024))
|
||||
} else {
|
||||
t.Logf("Total memory usage increased by %.2f MB during metadata operations", float64(totalMemoryIncrease)/(1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for gradual memory growth patterns
|
||||
intermediateMemoryIncrease := intermediate.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if intermediateMemoryIncrease > 10*1024*1024 { // 10MB threshold for first phase
|
||||
if strictMode {
|
||||
t.Errorf("Memory usage increased by %.2f MB during first phase of metadata operations", float64(intermediateMemoryIncrease)/(1024*1024))
|
||||
} else {
|
||||
t.Logf("Memory usage increased by %.2f MB during first phase of metadata operations", float64(intermediateMemoryIncrease)/(1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Check goroutine count stability
|
||||
goroutineIncrease := finalGoroutines - initialGoroutines
|
||||
if goroutineIncrease > 5 { // Allow some variance for test environment
|
||||
if strictMode {
|
||||
t.Errorf("Goroutine count increased by %d during metadata operations (initial: %d, final: %d)",
|
||||
goroutineIncrease, initialGoroutines, finalGoroutines)
|
||||
} else {
|
||||
t.Logf("Goroutine count increased by %d during metadata operations (initial: %d, final: %d)",
|
||||
goroutineIncrease, initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Test cache extension behavior on persistent failures
|
||||
t.Log("Phase 3: Testing cache extension on persistent failures...")
|
||||
|
||||
// Stop mock server to simulate provider unavailability
|
||||
mockServer.Close()
|
||||
|
||||
// Try multiple fetches after server shutdown
|
||||
postShutdownFailures := 0
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = metadataCache.GetMetadata(providerURL, httpClient, logger)
|
||||
if err != nil {
|
||||
postShutdownFailures++
|
||||
t.Logf("Expected failure %d after server shutdown: %v", i+1, err)
|
||||
} else {
|
||||
t.Logf("Unexpected success %d after server shutdown - cache extension working", i+1)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
if postShutdownFailures == 0 {
|
||||
if strictMode {
|
||||
t.Error("Expected some metadata fetches to fail after server shutdown")
|
||||
} else {
|
||||
t.Log("Warning: No metadata fetches failed after server shutdown - cache extension may not be working as expected")
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 4: Test background goroutine lifecycle and cleanup
|
||||
t.Log("Phase 4: Testing background goroutine lifecycle...")
|
||||
|
||||
// Wait longer to allow background cleanup to run
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Take final snapshot after cleanup
|
||||
finalAfterCleanup, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take final snapshot after cleanup: %v", err)
|
||||
}
|
||||
|
||||
// Check if memory decreased after cleanup
|
||||
if finalAfterCleanup.RuntimeStats.Alloc < current.RuntimeStats.Alloc {
|
||||
memoryDecrease := current.RuntimeStats.Alloc - finalAfterCleanup.RuntimeStats.Alloc
|
||||
t.Logf("Memory decreased by %.2f MB after cleanup phase", float64(memoryDecrease)/(1024*1024))
|
||||
}
|
||||
|
||||
// Clean up resources
|
||||
metadataCache.Close()
|
||||
wg.Wait() // Ensure all background goroutines complete
|
||||
|
||||
t.Logf("Test completed: %d total requests, %d server failures, %d post-shutdown failures",
|
||||
requestCount, serverFailures, postShutdownFailures)
|
||||
t.Logf("Memory usage: baseline=%.2f MB, intermediate=%.2f MB, final=%.2f MB",
|
||||
float64(baseline.RuntimeStats.Alloc)/(1024*1024),
|
||||
float64(intermediate.RuntimeStats.Alloc)/(1024*1024),
|
||||
float64(current.RuntimeStats.Alloc)/(1024*1024))
|
||||
}
|
||||
|
||||
// TestMemoryPoolLeakDetection tests for memory leaks in memory pool operations
|
||||
func TestMemoryPoolLeakDetection(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
|
||||
strictMode := os.Getenv("STRICT_MEMORY_TEST") == "true"
|
||||
if strictMode {
|
||||
t.Log("Running in strict memory test mode - will fail on detected leaks")
|
||||
} else {
|
||||
t.Log("Running in lenient memory test mode - will log warnings instead of failing")
|
||||
}
|
||||
|
||||
config := LeakDetectionConfig{
|
||||
EnableLeakDetection: true,
|
||||
LeakThresholdMB: 10,
|
||||
}
|
||||
|
||||
mto := NewMemoryTestOrchestrator(config, logger)
|
||||
|
||||
// Create memory pool manager and token compression pool
|
||||
memoryPoolManager := NewMemoryPoolManager()
|
||||
tokenCompressionPool := NewTokenCompressionPool()
|
||||
|
||||
// Create profiler for memory pools
|
||||
profiler := NewMemoryPoolProfiler(memoryPoolManager, tokenCompressionPool, logger)
|
||||
mto.RegisterComponent("memory_pools", profiler)
|
||||
|
||||
// Take initial baseline
|
||||
baseline, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take baseline snapshot: %v", err)
|
||||
}
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Phase 1: Simulate various memory pool operations
|
||||
t.Log("Phase 1: Testing memory pool operations with various patterns...")
|
||||
|
||||
// Test compression buffer pool
|
||||
for i := 0; i < 100; i++ {
|
||||
buf := memoryPoolManager.GetCompressionBuffer()
|
||||
// Simulate some work with the buffer
|
||||
buf.WriteString(fmt.Sprintf("test data %d", i))
|
||||
// Properly return buffer to pool
|
||||
memoryPoolManager.PutCompressionBuffer(buf)
|
||||
}
|
||||
|
||||
// Test JWT parsing buffer pool
|
||||
for i := 0; i < 50; i++ {
|
||||
jwtBuf := memoryPoolManager.GetJWTParsingBuffer()
|
||||
// Simulate JWT parsing operations
|
||||
jwtBuf.HeaderBuf = append(jwtBuf.HeaderBuf, []byte("header")...)
|
||||
jwtBuf.PayloadBuf = append(jwtBuf.PayloadBuf, []byte("payload")...)
|
||||
jwtBuf.SignatureBuf = append(jwtBuf.SignatureBuf, []byte("signature")...)
|
||||
// Properly return buffer to pool
|
||||
memoryPoolManager.PutJWTParsingBuffer(jwtBuf)
|
||||
}
|
||||
|
||||
// Test HTTP response buffer pool
|
||||
for i := 0; i < 75; i++ {
|
||||
httpBuf := memoryPoolManager.GetHTTPResponseBuffer()
|
||||
// Simulate HTTP response processing
|
||||
copy(httpBuf[:min(len(httpBuf), 100)], []byte("http response data"))
|
||||
// Properly return buffer to pool
|
||||
memoryPoolManager.PutHTTPResponseBuffer(httpBuf)
|
||||
}
|
||||
|
||||
// Test string builder pool
|
||||
for i := 0; i < 60; i++ {
|
||||
sb := memoryPoolManager.GetStringBuilder()
|
||||
// Simulate string building operations
|
||||
sb.WriteString(fmt.Sprintf("built string %d", i))
|
||||
_ = sb.String() // Use the result
|
||||
// Properly return string builder to pool
|
||||
memoryPoolManager.PutStringBuilder(sb)
|
||||
}
|
||||
|
||||
// Test token compression pool
|
||||
for i := 0; i < 40; i++ {
|
||||
compBuf := tokenCompressionPool.GetCompressionBuffer()
|
||||
// Simulate compression operations
|
||||
compBuf.WriteString(fmt.Sprintf("compress data %d", i))
|
||||
// Properly return buffer to pool
|
||||
tokenCompressionPool.PutCompressionBuffer(compBuf)
|
||||
|
||||
decompBuf := tokenCompressionPool.GetDecompressionBuffer()
|
||||
// Simulate decompression operations
|
||||
decompBuf.WriteString(fmt.Sprintf("decompress data %d", i))
|
||||
// Properly return buffer to pool
|
||||
tokenCompressionPool.PutDecompressionBuffer(decompBuf)
|
||||
|
||||
sb := tokenCompressionPool.GetStringBuilder()
|
||||
// Simulate string operations
|
||||
sb.WriteString(fmt.Sprintf("token string %d", i))
|
||||
_ = sb.String()
|
||||
// Properly return string builder to pool
|
||||
tokenCompressionPool.PutStringBuilder(sb)
|
||||
}
|
||||
|
||||
// Take intermediate snapshot
|
||||
intermediate, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take intermediate snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Phase 2: Continue with more intensive operations to test sustained usage
|
||||
t.Log("Phase 2: Testing sustained memory pool usage...")
|
||||
|
||||
// Simulate mixed operations with varying patterns
|
||||
for i := 0; i < 200; i++ {
|
||||
// Mix different pool operations
|
||||
switch i % 4 {
|
||||
case 0:
|
||||
buf := memoryPoolManager.GetCompressionBuffer()
|
||||
buf.WriteString("mixed operation data")
|
||||
memoryPoolManager.PutCompressionBuffer(buf)
|
||||
case 1:
|
||||
jwtBuf := memoryPoolManager.GetJWTParsingBuffer()
|
||||
jwtBuf.HeaderBuf = append(jwtBuf.HeaderBuf, []byte("mixed")...)
|
||||
memoryPoolManager.PutJWTParsingBuffer(jwtBuf)
|
||||
case 2:
|
||||
httpBuf := memoryPoolManager.GetHTTPResponseBuffer()
|
||||
copy(httpBuf[:min(len(httpBuf), 50)], []byte("mixed http"))
|
||||
memoryPoolManager.PutHTTPResponseBuffer(httpBuf)
|
||||
case 3:
|
||||
sb := memoryPoolManager.GetStringBuilder()
|
||||
sb.WriteString("mixed string building")
|
||||
_ = sb.String()
|
||||
memoryPoolManager.PutStringBuilder(sb)
|
||||
}
|
||||
}
|
||||
|
||||
// Take final snapshot
|
||||
current, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take current snapshot: %v", err)
|
||||
}
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Analyze for leaks
|
||||
analysis := profiler.AnalyzeLeaks(baseline, current)
|
||||
|
||||
// Assertions for memory leaks
|
||||
if analysis.HasLeak {
|
||||
if strictMode {
|
||||
t.Errorf("Memory leak detected in memory pool operations: %s", analysis.LeakDescription)
|
||||
for _, leak := range analysis.SuspectedLeaks {
|
||||
t.Errorf("Suspected leak: %s", leak)
|
||||
}
|
||||
} else {
|
||||
t.Logf("Memory leak warning in memory pool operations: %s", analysis.LeakDescription)
|
||||
for _, leak := range analysis.SuspectedLeaks {
|
||||
t.Logf("Suspected leak: %s", leak)
|
||||
}
|
||||
}
|
||||
for _, rec := range analysis.Recommendations {
|
||||
t.Logf("Recommendation: %s", rec)
|
||||
}
|
||||
}
|
||||
|
||||
// Check total memory growth
|
||||
totalMemoryIncrease := current.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if totalMemoryIncrease > 15*1024*1024 { // 15MB threshold for entire test
|
||||
if strictMode {
|
||||
t.Errorf("Total memory usage increased by %.2f MB during memory pool operations", float64(totalMemoryIncrease)/(1024*1024))
|
||||
} else {
|
||||
t.Logf("Total memory usage increased by %.2f MB during memory pool operations", float64(totalMemoryIncrease)/(1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for gradual memory growth patterns
|
||||
intermediateMemoryIncrease := intermediate.RuntimeStats.Alloc - baseline.RuntimeStats.Alloc
|
||||
if intermediateMemoryIncrease > 8*1024*1024 { // 8MB threshold for first phase
|
||||
if strictMode {
|
||||
t.Errorf("Memory usage increased by %.2f MB during first phase of memory pool operations", float64(intermediateMemoryIncrease)/(1024*1024))
|
||||
} else {
|
||||
t.Logf("Memory usage increased by %.2f MB during first phase of memory pool operations", float64(intermediateMemoryIncrease)/(1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
// Check goroutine count stability
|
||||
goroutineIncrease := finalGoroutines - initialGoroutines
|
||||
if goroutineIncrease > 3 { // Allow small variance for test environment
|
||||
if strictMode {
|
||||
t.Errorf("Goroutine count increased by %d during memory pool operations (initial: %d, final: %d)",
|
||||
goroutineIncrease, initialGoroutines, finalGoroutines)
|
||||
} else {
|
||||
t.Logf("Goroutine count increased by %d during memory pool operations (initial: %d, final: %d)",
|
||||
goroutineIncrease, initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Test cleanup verification
|
||||
t.Log("Phase 3: Testing cleanup verification...")
|
||||
|
||||
// Force garbage collection to see if pools are properly managed
|
||||
runtime.GC()
|
||||
runtime.GC() // Run twice to ensure cleanup
|
||||
|
||||
time.Sleep(10 * time.Millisecond) // Allow cleanup to complete
|
||||
|
||||
// Take post-cleanup snapshot
|
||||
postCleanup, err := profiler.TakeSnapshot()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to take post-cleanup snapshot: %v", err)
|
||||
}
|
||||
|
||||
// Check if memory decreased after cleanup
|
||||
if postCleanup.RuntimeStats.Alloc < current.RuntimeStats.Alloc {
|
||||
memoryDecrease := current.RuntimeStats.Alloc - postCleanup.RuntimeStats.Alloc
|
||||
t.Logf("Memory decreased by %.2f MB after cleanup phase", float64(memoryDecrease)/(1024*1024))
|
||||
} else if postCleanup.RuntimeStats.Alloc > current.RuntimeStats.Alloc {
|
||||
memoryIncrease := postCleanup.RuntimeStats.Alloc - current.RuntimeStats.Alloc
|
||||
if strictMode {
|
||||
t.Errorf("Memory increased by %.2f MB after cleanup phase - possible cleanup issues", float64(memoryIncrease)/(1024*1024))
|
||||
} else {
|
||||
t.Logf("Memory increased by %.2f MB after cleanup phase - possible cleanup issues", float64(memoryIncrease)/(1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Memory pool leak detection test completed")
|
||||
t.Logf("Memory usage: baseline=%.2f MB, intermediate=%.2f MB, final=%.2f MB, post-cleanup=%.2f MB",
|
||||
float64(baseline.RuntimeStats.Alloc)/(1024*1024),
|
||||
float64(intermediate.RuntimeStats.Alloc)/(1024*1024),
|
||||
float64(current.RuntimeStats.Alloc)/(1024*1024),
|
||||
float64(postCleanup.RuntimeStats.Alloc)/(1024*1024))
|
||||
}
|
||||
@@ -74,6 +74,9 @@ func TestRefreshGracePeriodConfiguration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTokenRefreshWithinGracePeriod(t *testing.T) {
|
||||
// Reset global state to prevent test interference
|
||||
resetGlobalState()
|
||||
|
||||
refreshCount := int32(0)
|
||||
tokenVersion := int32(1)
|
||||
|
||||
@@ -286,6 +289,9 @@ func TestGracePeriodWithProviderSpecificBehavior(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRefreshGracePeriodConcurrency(t *testing.T) {
|
||||
// Reset global state to prevent test interference
|
||||
resetGlobalState()
|
||||
|
||||
var refreshMutex sync.Mutex
|
||||
refreshCount := 0
|
||||
blockedRequests := int32(0)
|
||||
|
||||
@@ -40,6 +40,7 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a fresh instance for this test
|
||||
var goroutineWG sync.WaitGroup
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
@@ -51,6 +52,7 @@ func TestConcurrentTokenVerification(t *testing.T) {
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
goroutineWG: &goroutineWG,
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
@@ -497,6 +499,7 @@ func TestMaliciousInputValidation(t *testing.T) {
|
||||
for _, test := range maliciousInputs {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Create a fresh instance for each test to avoid rate limiting issues
|
||||
var goroutineWG sync.WaitGroup
|
||||
freshOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
@@ -508,6 +511,7 @@ func TestMaliciousInputValidation(t *testing.T) {
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
goroutineWG: &goroutineWG,
|
||||
}
|
||||
freshOidc.tokenVerifier = freshOidc
|
||||
freshOidc.jwtVerifier = freshOidc
|
||||
@@ -731,6 +735,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create fresh instance with high rate limit
|
||||
var goroutineWG sync.WaitGroup
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
@@ -742,6 +747,7 @@ func TestPerformanceUnderLoad(t *testing.T) {
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
goroutineWG: &goroutineWG,
|
||||
}
|
||||
tOidc.tokenVerifier = tOidc
|
||||
tOidc.jwtVerifier = tOidc
|
||||
|
||||
+14
-3
@@ -245,7 +245,8 @@ type SessionManager struct {
|
||||
forceHTTPS bool
|
||||
cookieDomain string
|
||||
chunkManager *ChunkManager
|
||||
cleanupDone bool // Track if we've attempted cookie cleanup
|
||||
cleanupMutex sync.RWMutex // Protects cleanup operations
|
||||
cleanupDone bool // Track if we've attempted cookie cleanup
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager with the specified configuration.
|
||||
@@ -628,7 +629,11 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
// If we have a configured domain and this is the first request after config change,
|
||||
// attempt to delete the cookie with various domain variations to ensure cleanup
|
||||
if currentDomain != "" && !sm.cleanupDone {
|
||||
sm.cleanupMutex.RLock()
|
||||
shouldCleanup := currentDomain != "" && !sm.cleanupDone
|
||||
sm.cleanupMutex.RUnlock()
|
||||
|
||||
if shouldCleanup {
|
||||
for _, domain := range domainsToClean {
|
||||
// Skip the current configured domain
|
||||
if domain == currentDomain || domain == "."+currentDomain || "."+domain == currentDomain {
|
||||
@@ -655,7 +660,9 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
// Mark cleanup as done for this session manager instance
|
||||
if !sm.cleanupDone && len(processedCookies) > 0 {
|
||||
sm.cleanupMutex.Lock()
|
||||
sm.cleanupDone = true
|
||||
sm.cleanupMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -671,6 +678,8 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
sessionData.dirty = false
|
||||
|
||||
var sessionReturned bool
|
||||
// CRITICAL FIX: Defer logic to return session to pool only on panic, not on successful return
|
||||
// This prevents memory leaks by ensuring sessions are only pooled when not in use by caller
|
||||
defer func() {
|
||||
if !sessionReturned && sessionData != nil {
|
||||
if r := recover(); r != nil {
|
||||
@@ -736,7 +745,9 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||
sm.getTokenChunkSessions(r, idTokenCookie, sessionData.idTokenChunks)
|
||||
|
||||
sessionReturned = false
|
||||
// CRITICAL FIX: Mark session as successfully returned to caller - prevents premature pool return
|
||||
// This fixes memory leak where sessions were incorrectly returned to pool immediately after success
|
||||
sessionReturned = true
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
|
||||
+37
-1
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,33 @@ func (w *testWriter) Write(p []byte) (n int, err error) {
|
||||
|
||||
// Test helper adapters for the new test files
|
||||
|
||||
// resetGlobalState resets all global singletons to prevent test interference
|
||||
func resetGlobalState() {
|
||||
// Reset global cache manager
|
||||
cacheManagerMutex.Lock()
|
||||
if globalCacheManager != nil {
|
||||
globalCacheManager.Close()
|
||||
globalCacheManager = nil
|
||||
}
|
||||
cacheManagerOnce = sync.Once{}
|
||||
cacheManagerMutex.Unlock()
|
||||
|
||||
// Reset replay cache
|
||||
replayCacheMu.Lock()
|
||||
if replayCache != nil {
|
||||
replayCache.Close()
|
||||
replayCache = nil
|
||||
}
|
||||
replayCacheOnce = sync.Once{}
|
||||
replayCacheMu.Unlock()
|
||||
|
||||
// Reset memory pools
|
||||
memoryPoolMutex.Lock()
|
||||
globalMemoryPools = nil
|
||||
memoryPoolOnce = sync.Once{}
|
||||
memoryPoolMutex.Unlock()
|
||||
}
|
||||
|
||||
// createTestConfig creates a config with all required fields populated for testing
|
||||
func createTestConfig() *Config {
|
||||
config := CreateConfig()
|
||||
@@ -38,6 +66,9 @@ func createTestConfig() *Config {
|
||||
|
||||
// setupTestOIDCMiddleware creates a test OIDC middleware instance with mock servers
|
||||
func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httptest.Server) {
|
||||
// Reset global state to ensure test isolation
|
||||
resetGlobalState()
|
||||
|
||||
// Create mock OIDC server
|
||||
var serverURL string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -114,6 +145,9 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
|
||||
testTokenURL := testIssuerURL + "/token"
|
||||
testJWKSURL := testIssuerURL + "/keys"
|
||||
|
||||
// Create WaitGroup for background goroutines
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Create TraefikOidc instance directly
|
||||
oidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
@@ -143,8 +177,10 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
allowedUserDomains: make(map[string]struct{}),
|
||||
jwkCache: &JWKCache{},
|
||||
metadataCache: NewMetadataCache(),
|
||||
metadataCache: NewMetadataCache(nil),
|
||||
ctx: context.Background(),
|
||||
goroutineWG: &wg,
|
||||
providerURL: serverURL,
|
||||
}
|
||||
|
||||
// Process excluded URLs
|
||||
|
||||
+1
-1
@@ -14,4 +14,4 @@ func generateRandomString(length int) string {
|
||||
return "random-string-fallback"
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user