Files
traefikoidc/singleton_resources.go
lukaszraczylo 2d1b04c637 review fixes apr 2026 (#130)
* Multiple fixes

- refresh coordinator dedup + memory pressure wire
- middleware sse consolidation + timer leak + claim cache
- universal cache sync backfill + isDebug gate
- lazy background task race
- memory monitor stw cached + refresh() api

* fix(auth): suppress OIDC redirects on non-navigation requests

- [x] Add isNonNavigationRequest using Sec-Fetch-Mode and Accept headers
- [x] Add comprehensive TestIsNonNavigationRequest
- [x] Update ServeHTTP to 401 non-navigation and AJAX requests

Fixes #129

* feat(config): add custom CA and insecure skip verify for OIDC TLS

- [x] Add CACertPath, CACertPEM, InsecureSkipVerify to Config
- [x] Implement loadCACertPool for CA bundle loading
- [x] Update HTTPClientConfig with RootCAs and InsecureSkipVerify
- [x] Apply CA pool and skip verify to pooled HTTP clients
- [x] Enhance configKey to distinguish TLS configs
- [x] Add comprehensive ca_cert_test.go

Fixes #125

* feat(oidc): add custom CA certificate support for private OIDC providers

- [x] Add caCertPath, caCertPEM, insecureSkipVerify config options
- [x] Update traefik.yml with new OIDC client config fields
- [x] Add configuration schema descriptions for new options
- [x] Update README table and add Custom CA Certificates section

* Fix the documentation.

* test(redis): add oversized argument rejection test

- [x] Add TestRedisConn_RejectOversizedArgumentBytes
- [x] Import strings package

* Dependencies cleanup
2026-04-19 10:12:00 +01:00

581 lines
14 KiB
Go

package traefikoidc
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
)
// Package-level state backing the process-wide ResourceManager singleton.
var (
// globalResourceManager is the shared instance; it is assigned inside
// GetResourceManager under resourceManagerOnce.
globalResourceManager *ResourceManager
// resourceManagerOnce makes the construction in GetResourceManager run once.
resourceManagerOnce sync.Once
// resourceManagerMutex is not referenced by the code visible in this file.
// NOTE(review): presumably used by reset/test helpers elsewhere — confirm.
resourceManagerMutex sync.Mutex
)
// ResourceManager manages shared resources across all middleware instances
// to prevent duplication and goroutine leaks when Traefik recreates middleware.
//
// Each map is guarded by its own RWMutex so that lookups of one resource kind
// do not contend with another.
type ResourceManager struct {
// references maps an instance ID to its reference counter
// (see AddReference / RemoveReference).
references map[string]*int32
// caches maps a cache key to the shared cache object returned by GetCache.
caches map[string]interface{}
// httpClients maps a key to the shared client returned by GetHTTPClient.
httpClients map[string]*http.Client
// tasks maps a task name to its registered background task.
tasks map[string]*BackgroundTask
// shutdownChan is closed once, at the start of Shutdown.
shutdownChan chan struct{}
// pools maps a key to the shared pool returned by GetGoroutinePool.
pools map[string]*GoroutinePool
// logger receives diagnostics; every use in this file is nil-checked.
logger *Logger
// wg is handed to the cache manager and to background tasks, and is
// waited on by Shutdown.
wg sync.WaitGroup
// cachesMu guards caches.
cachesMu sync.RWMutex
// referencesMu guards references.
referencesMu sync.RWMutex
// poolsMu guards pools.
poolsMu sync.RWMutex
// tasksMu guards tasks.
tasksMu sync.RWMutex
// clientsMu guards httpClients.
clientsMu sync.RWMutex
// shutdownOnce makes the body of Shutdown run at most once.
shutdownOnce sync.Once
}
// GetResourceManager hands back the process-wide ResourceManager, building it
// on the first call. Construction is guarded by resourceManagerOnce, so every
// caller receives the same instance.
func GetResourceManager() *ResourceManager {
	resourceManagerOnce.Do(func() {
		rm := new(ResourceManager)
		rm.httpClients = map[string]*http.Client{}
		rm.caches = map[string]interface{}{}
		rm.tasks = map[string]*BackgroundTask{}
		rm.pools = map[string]*GoroutinePool{}
		rm.references = map[string]*int32{}
		rm.logger = GetSingletonNoOpLogger()
		rm.shutdownChan = make(chan struct{})

		// Publish only once the instance is fully populated.
		globalResourceManager = rm
	})
	return globalResourceManager
}
// GetHTTPClient returns the shared HTTP client stored under key, creating and
// caching one on first request.
//
// Lookups take only the read lock. On a miss the write lock is taken and the
// map is checked again, so goroutines racing on the same key end up with a
// single client. New clients are built by NewHTTPClientFactory from
// DefaultHTTPClientConfig.
func (rm *ResourceManager) GetHTTPClient(key string) *http.Client {
	rm.clientsMu.RLock()
	hc, ok := rm.httpClients[key]
	rm.clientsMu.RUnlock()
	if ok {
		return hc
	}

	rm.clientsMu.Lock()
	defer rm.clientsMu.Unlock()

	// Another goroutine may have stored a client between the two locks.
	if raced, present := rm.httpClients[key]; present {
		return raced
	}

	cfg := DefaultHTTPClientConfig()
	hc = NewHTTPClientFactory().CreateHTTPClient(cfg)
	rm.httpClients[key] = hc
	return hc
}
// GetCache returns the shared cache stored under key, creating it on first use.
//
// The keys "metadata-cache", "token-cache" and "jwk-cache" resolve to the
// corresponding shared caches of the global cache manager; any other key gets
// a GenericCache with a one-hour interval. The result is returned as
// interface{}, so callers must type-assert it.
func (rm *ResourceManager) GetCache(key string) interface{} {
	rm.cachesMu.RLock()
	found, ok := rm.caches[key]
	rm.cachesMu.RUnlock()
	if ok {
		return found
	}

	rm.cachesMu.Lock()
	defer rm.cachesMu.Unlock()

	// Re-check: the entry may have been created while we waited for the lock.
	if raced, present := rm.caches[key]; present {
		return raced
	}

	// The cache manager is obtained before dispatching on key, including for
	// the generic fallback.
	manager := GetGlobalCacheManager(&rm.wg)

	var created interface{}
	switch key {
	case "metadata-cache":
		created = manager.GetSharedMetadataCache()
	case "token-cache":
		created = manager.GetSharedTokenCache()
	case "jwk-cache":
		created = manager.GetSharedJWKCache()
	default:
		created = NewGenericCache(1*time.Hour, rm.logger)
	}

	rm.caches[key] = created
	return created
}
// RegisterBackgroundTask records a singleton background task under name.
//
// Registration is idempotent: if a task with that name already exists it is
// kept as is and nil is returned. The task is only created here, not started;
// use StartBackgroundTask for that. The current implementation always
// returns nil.
func (rm *ResourceManager) RegisterBackgroundTask(name string, interval time.Duration, taskFunc func()) error {
	rm.tasksMu.Lock()
	defer rm.tasksMu.Unlock()

	_, alreadyRegistered := rm.tasks[name]
	if !alreadyRegistered {
		// The task shares rm.wg so that Shutdown can wait for it.
		rm.tasks[name] = NewBackgroundTask(name, interval, taskFunc, rm.logger, &rm.wg)
		if rm.logger != nil {
			rm.logger.Infof("Registered singleton background task: %s", name)
		}
		return nil
	}

	if rm.logger != nil {
		rm.logger.Debugf("Background task %s already registered", name)
	}
	return nil
}
// StartBackgroundTask starts the task previously registered under name.
// It returns an error if no such task is registered. The task's Start method
// is called after the registry lock has been released.
func (rm *ResourceManager) StartBackgroundTask(name string) error {
	rm.tasksMu.RLock()
	registered, found := rm.tasks[name]
	rm.tasksMu.RUnlock()

	if found {
		registered.Start()
		return nil
	}
	return fmt.Errorf("task %s not registered", name)
}
// StopBackgroundTask stops the task registered under name.
// It returns an error if no such task is registered. The task's Stop method
// is called after the registry lock has been released; the task stays
// registered.
func (rm *ResourceManager) StopBackgroundTask(name string) error {
	rm.tasksMu.RLock()
	registered, found := rm.tasks[name]
	rm.tasksMu.RUnlock()

	if found {
		registered.Stop()
		return nil
	}
	return fmt.Errorf("task %s not registered", name)
}
// IsTaskRunning reports whether the task registered under name has been
// started and not yet stopped. Unknown names report false.
func (rm *ResourceManager) IsTaskRunning(name string) bool {
	rm.tasksMu.RLock()
	registered, found := rm.tasks[name]
	rm.tasksMu.RUnlock()

	if !found {
		return false
	}
	// started must be 1; only then is the stopped flag consulted.
	if atomic.LoadInt32(&registered.started) != 1 {
		return false
	}
	return atomic.LoadInt32(&registered.stopped) == 0
}
// GetGoroutinePool returns the shared goroutine pool stored under key,
// creating it with maxWorkers workers on first request.
//
// maxWorkers is only used when the pool is created; later calls with the same
// key return the existing pool regardless of the value passed.
func (rm *ResourceManager) GetGoroutinePool(key string, maxWorkers int) *GoroutinePool {
	rm.poolsMu.RLock()
	shared, ok := rm.pools[key]
	rm.poolsMu.RUnlock()
	if ok {
		return shared
	}

	rm.poolsMu.Lock()
	defer rm.poolsMu.Unlock()

	// Re-check under the write lock in case another goroutine won the race.
	if raced, present := rm.pools[key]; present {
		return raced
	}

	shared = NewGoroutinePool(maxWorkers, rm.logger)
	rm.pools[key] = shared
	return shared
}
// AddReference increments the reference count kept for instanceID, starting
// the count at 1 the first time the ID is seen.
func (rm *ResourceManager) AddReference(instanceID string) {
	rm.referencesMu.Lock()
	defer rm.referencesMu.Unlock()

	counter, tracked := rm.references[instanceID]
	if tracked {
		atomic.AddInt32(counter, 1)
	} else {
		counter = new(int32)
		*counter = 1
		rm.references[instanceID] = counter
	}

	if rm.logger != nil {
		rm.logger.Debugf("Added reference for instance %s", instanceID)
	}
}
// RemoveReference decrements the reference count kept for instanceID.
//
// When the count drops to zero or below, the entry is removed and
// cleanupInstance is invoked (while referencesMu is still held). Unknown IDs
// are ignored.
func (rm *ResourceManager) RemoveReference(instanceID string) {
	rm.referencesMu.Lock()
	defer rm.referencesMu.Unlock()

	counter, tracked := rm.references[instanceID]
	if !tracked {
		return
	}
	if remaining := atomic.AddInt32(counter, -1); remaining > 0 {
		return
	}

	// Last reference released: forget the instance and run its cleanup hook.
	delete(rm.references, instanceID)
	if rm.logger != nil {
		rm.logger.Debugf("Removed last reference for instance %s", instanceID)
	}
	rm.cleanupInstance(instanceID)
}
// GetReferenceCount reports the current reference count for instanceID,
// or 0 when the ID is not tracked.
func (rm *ResourceManager) GetReferenceCount(instanceID string) int32 {
	rm.referencesMu.RLock()
	defer rm.referencesMu.RUnlock()

	counter, tracked := rm.references[instanceID]
	if !tracked {
		return 0
	}
	return atomic.LoadInt32(counter)
}
// cleanupInstance is the hook invoked when an instance's reference count
// reaches zero. At present it only logs; it exists as the place to add
// instance-specific cleanup later.
func (rm *ResourceManager) cleanupInstance(instanceID string) {
	if rm.logger == nil {
		return
	}
	rm.logger.Infof("Cleaning up resources for instance %s", instanceID)
}
// Shutdown stops every managed resource and waits for outstanding goroutines.
//
// Sequence: close shutdownChan, stop all registered background tasks, shut
// down all goroutine pools (passing ctx to each), then wait on rm.wg until it
// drains or ctx is done. The first pool error is kept unless the final wait
// times out, in which case a "shutdown timeout" error wrapping ctx.Err() is
// returned instead.
//
// The work runs at most once; any later call does nothing and returns nil.
func (rm *ResourceManager) Shutdown(ctx context.Context) error {
	var result error

	rm.shutdownOnce.Do(func() {
		close(rm.shutdownChan)
		if rm.logger != nil {
			rm.logger.Info("Starting ResourceManager shutdown")
		}

		// Snapshot the tasks so Stop is called without holding tasksMu.
		rm.tasksMu.RLock()
		taskSnapshot := make([]*BackgroundTask, 0, len(rm.tasks))
		for _, registered := range rm.tasks {
			taskSnapshot = append(taskSnapshot, registered)
		}
		rm.tasksMu.RUnlock()
		for i := range taskSnapshot {
			taskSnapshot[i].Stop()
		}

		// Same approach for the pools; remember only the first failure.
		rm.poolsMu.RLock()
		poolSnapshot := make([]*GoroutinePool, 0, len(rm.pools))
		for _, shared := range rm.pools {
			poolSnapshot = append(poolSnapshot, shared)
		}
		rm.poolsMu.RUnlock()
		for i := range poolSnapshot {
			poolErr := poolSnapshot[i].Shutdown(ctx)
			if poolErr != nil && result == nil {
				result = poolErr
			}
		}

		// Bridge the WaitGroup to a channel so it can be raced against ctx.
		drained := make(chan struct{})
		go func() {
			rm.wg.Wait()
			close(drained)
		}()

		select {
		case <-ctx.Done():
			result = fmt.Errorf("shutdown timeout: %w", ctx.Err())
			if rm.logger != nil {
				rm.logger.Errorf("ResourceManager shutdown timeout: %v", result)
			}
		case <-drained:
			if rm.logger != nil {
				rm.logger.Info("ResourceManager shutdown completed successfully")
			}
		}
	})

	return result
}
// GoroutinePool provides a pool of workers for controlled concurrency.
// Tasks are handed to workers through a buffered queue; the pool tracks how
// many tasks are outstanding so callers can wait for them.
type GoroutinePool struct {
// taskQueue carries submitted tasks to the workers; NewGoroutinePool
// creates it with capacity maxWorkers*2.
taskQueue chan func()
// shutdownChan is closed by Shutdown to tell workers and Submit to stop.
shutdownChan chan struct{}
// logger receives diagnostics; every use is nil-checked.
logger *Logger
// taskCond is broadcast by a worker when pendingTasks drops to zero;
// Wait blocks on it.
taskCond *sync.Cond
// workerWG tracks the worker goroutines; Shutdown waits on it.
workerWG sync.WaitGroup
// maxWorkers is the worker count the pool was created with.
maxWorkers int
// pendingTasks counts tasks that are queued or in progress (atomic).
pendingTasks int64
// shutdownOnce makes the body of Shutdown run at most once.
shutdownOnce sync.Once
// started is 1 after NewGoroutinePool has launched the workers and is
// reset to 0 by Shutdown (atomic).
started int32
}
// NewGoroutinePool creates a goroutine pool and starts its workers.
//
// maxWorkers is the number of worker goroutines; the task queue is buffered
// to twice that. Values below 1 are raised to 1: a negative value would make
// the queue allocation panic, and zero would produce a pool with no workers
// whose every Submit times out.
//
// logger may be nil, in which case nothing is logged.
func NewGoroutinePool(maxWorkers int, logger *Logger) *GoroutinePool {
	if maxWorkers < 1 {
		maxWorkers = 1
	}

	pool := &GoroutinePool{
		maxWorkers:   maxWorkers,
		taskQueue:    make(chan func(), maxWorkers*2), // Buffer for queuing
		shutdownChan: make(chan struct{}),
		logger:       logger,
		taskCond:     sync.NewCond(&sync.Mutex{}),
		pendingTasks: 0,
	}

	// Register each worker with the WaitGroup before launching it so that
	// Shutdown can never observe a partially counted set.
	for i := 0; i < maxWorkers; i++ {
		pool.workerWG.Add(1)
		go pool.worker(i)
	}
	atomic.StoreInt32(&pool.started, 1)

	if logger != nil {
		logger.Infof("Created goroutine pool with %d workers", maxWorkers)
	}
	return pool
}
// worker is the main loop of one pool worker; id is used only in log output.
//
// It takes tasks from taskQueue until shutdownChan is closed. A panic inside
// a task is recovered and logged so that one bad task cannot kill the worker.
//
// Every item taken from the queue decrements pendingTasks, including a nil
// task: Submit counts a task before queuing it without checking for nil, so
// skipping the decrement for nil would leave the counter above zero forever
// and block Wait. When the counter reaches zero all waiters are woken.
func (p *GoroutinePool) worker(id int) {
	defer p.workerWG.Done()

	for {
		select {
		case task := <-p.taskQueue:
			if task != nil {
				// Run inside a closure so recover is scoped to this task.
				func() {
					defer func() {
						if r := recover(); r != nil && p.logger != nil {
							p.logger.Errorf("Worker %d panic recovered: %v", id, r)
						}
					}()
					task()
				}()
			}

			// Broadcast under the condition's lock so a waiter that has
			// just checked the counter cannot miss the wake-up.
			if atomic.AddInt64(&p.pendingTasks, -1) == 0 {
				p.taskCond.L.Lock()
				p.taskCond.Broadcast()
				p.taskCond.L.Unlock()
			}

		case <-p.shutdownChan:
			if p.logger != nil {
				p.logger.Debugf("Worker %d shutting down", id)
			}
			return
		}
	}
}
// Submit queues task for execution by a worker.
//
// It returns nil once the task is queued, or an error when:
//   - the pool has been shut down ("pool is shutdown"),
//   - shutdown begins while submitting ("pool is shutting down"), or
//   - the queue stays full for 100ms ("task queue is full").
//
// pendingTasks is incremented before queuing so that Wait cannot see a
// queued task as already finished. If the task is not queued after all, the
// increment is undone and, when that brings the counter to zero, waiters are
// woken. Without that wake-up a Wait caller could sleep forever if the last
// worker finished while this submission was still counted.
func (p *GoroutinePool) Submit(task func()) error {
	if atomic.LoadInt32(&p.started) == 0 {
		return fmt.Errorf("pool is shutdown")
	}

	// Increment pending task count BEFORE queuing to avoid race with Wait().
	atomic.AddInt64(&p.pendingTasks, 1)

	// release undoes the increment for a task that will never run.
	release := func() {
		if atomic.AddInt64(&p.pendingTasks, -1) == 0 {
			p.taskCond.L.Lock()
			p.taskCond.Broadcast()
			p.taskCond.L.Unlock()
		}
	}

	// Fast path: queue immediately if there is room.
	select {
	case p.taskQueue <- task:
		return nil
	case <-p.shutdownChan:
		release()
		return fmt.Errorf("pool is shutting down")
	default:
	}

	// Queue is full: give it a short grace period. An explicit timer is
	// used so it can be released as soon as the wait ends.
	grace := time.NewTimer(100 * time.Millisecond)
	defer grace.Stop()

	select {
	case p.taskQueue <- task:
		return nil
	case <-grace.C:
		release()
		return fmt.Errorf("task queue is full")
	case <-p.shutdownChan:
		release()
		return fmt.Errorf("pool is shutting down")
	}
}
// Wait blocks until no submitted task is pending (queued or running).
//
// It sleeps on the pool's condition variable rather than polling, and
// re-checks the counter after every wake-up. It has no timeout; see
// WaitWithTimeout for a bounded variant.
func (p *GoroutinePool) Wait() {
	p.taskCond.L.Lock()
	defer p.taskCond.L.Unlock()

	for {
		if atomic.LoadInt64(&p.pendingTasks) <= 0 {
			return
		}
		// Releases the lock while asleep and re-acquires it on wake-up.
		p.taskCond.Wait()
	}
}
// WaitWithTimeout waits for all submitted tasks to complete, giving up after
// timeout. It returns true if every task finished in time and false on
// timeout.
//
// On timeout the helper goroutine keeps blocking in Wait until the pool
// eventually drains.
func (p *GoroutinePool) WaitWithTimeout(timeout time.Duration) bool {
	finished := make(chan struct{})
	go func() {
		p.Wait()
		close(finished)
	}()

	expiry := time.NewTimer(timeout)
	defer expiry.Stop()

	select {
	case <-expiry.C:
		return false
	case <-finished:
		return true
	}
}
// PendingTasks reports how many tasks are currently outstanding, i.e. queued
// or being executed. The value is an atomic snapshot and may be stale by the
// time the caller uses it.
func (p *GoroutinePool) PendingTasks() int64 {
	outstanding := atomic.LoadInt64(&p.pendingTasks)
	return outstanding
}
// Shutdown stops the pool: it marks the pool as not started so Submit rejects
// new work, closes shutdownChan so workers exit after their current task, and
// waits for the workers until ctx is done.
//
// It returns nil when all workers exited in time, or an error wrapping
// ctx.Err() on timeout. The work runs at most once; later calls do nothing
// and return nil.
//
// Tasks still sitting in the queue are never executed once workers stop.
// They are removed here and subtracted from pendingTasks, and waiters are
// woken when the counter reaches zero. Otherwise a Wait call made after
// Shutdown would block forever on tasks that can no longer run.
func (p *GoroutinePool) Shutdown(ctx context.Context) error {
	var err error
	p.shutdownOnce.Do(func() {
		atomic.StoreInt32(&p.started, 0)
		close(p.shutdownChan)

		// Wait for workers to finish with context timeout.
		done := make(chan struct{})
		go func() {
			p.workerWG.Wait()
			close(done)
		}()

		select {
		case <-done:
			if p.logger != nil {
				p.logger.Debug("Goroutine pool shutdown completed")
			}
		case <-ctx.Done():
			err = fmt.Errorf("pool shutdown timeout: %w", ctx.Err())
			if p.logger != nil {
				p.logger.Errorf("Goroutine pool shutdown timeout: %v", err)
			}
		}

		// Discard whatever is left in the queue and fix up the accounting.
		// NOTE: a Submit that passed its started-check just before Shutdown
		// may still enqueue after this drain; that window is very small.
		for {
			select {
			case <-p.taskQueue:
				if atomic.AddInt64(&p.pendingTasks, -1) == 0 {
					p.taskCond.L.Lock()
					p.taskCond.Broadcast()
					p.taskCond.L.Unlock()
				}
			default:
				return
			}
		}
	})
	return err
}
// GenericCache provides a simple cache implementation for testing.
// It is a mutex-guarded map with no per-entry expiry; a background routine
// clears the whole map every ttl.
type GenericCache struct {
// data holds the cached values; guarded by mu.
data map[string]interface{}
// logger is stored by NewGenericCache; it is not used by the methods
// visible in this file.
logger *Logger
// stopChan is closed by Stop to end the cleanup routine.
stopChan chan struct{}
// ttl is the interval between whole-cache wipes.
ttl time.Duration
// mu guards data.
mu sync.RWMutex
}
// NewGenericCache builds an empty GenericCache and launches its background
// cleanup routine. ttl is the interval at which that routine runs; logger is
// stored on the cache. Call Stop to end the routine.
func NewGenericCache(ttl time.Duration, logger *Logger) *GenericCache {
	gc := new(GenericCache)
	gc.data = map[string]interface{}{}
	gc.ttl = ttl
	gc.logger = logger
	gc.stopChan = make(chan struct{})

	// The routine runs until stopChan is closed.
	go gc.cleanupRoutine()

	return gc
}
// Get looks up key and returns the stored value together with a flag that
// reports whether the key was present. The lookup takes the read lock.
func (gc *GenericCache) Get(key string) (interface{}, bool) {
	gc.mu.RLock()
	stored, present := gc.data[key]
	gc.mu.RUnlock()

	return stored, present
}
// Set stores value under key, overwriting any existing entry.
// The write happens under the cache's write lock. No per-entry timestamp is
// recorded, so the entry lives until it is deleted or the cleanup routine
// next clears the whole map.
func (gc *GenericCache) Set(key string, value interface{}) {
gc.mu.Lock()
defer gc.mu.Unlock()
gc.data[key] = value
}
// Delete drops the entry stored under key, if any. Deleting a key that is
// not present is a no-op. The removal takes the write lock.
func (gc *GenericCache) Delete(key string) {
	gc.mu.Lock()
	delete(gc.data, key)
	gc.mu.Unlock()
}
// cleanupRoutine periodically wipes the cache until stopChan is closed.
//
// NOTE: GenericCache does not track per-entry timestamps, so this is a
// "clear-all on tick" strategy — every `gc.ttl` interval the entire map
// is replaced, regardless of when each entry was written. This is the
// intentional (simplified) behavior of GenericCache, which exists mainly
// as a generic fallback for tests and non-typed caches. Callers that
// require true per-entry TTL must use UniversalCache / UnifiedCache which
// track expiry per entry.
//
// A non-positive ttl disables the periodic wipe: time.NewTicker panics for
// such intervals, and a panic in this background goroutine would take down
// the whole process. In that case the routine returns immediately and
// entries stay until they are deleted explicitly.
func (gc *GenericCache) cleanupRoutine() {
	if gc.ttl <= 0 {
		return
	}

	wipeTicker := time.NewTicker(gc.ttl)
	defer wipeTicker.Stop()

	for {
		select {
		case <-wipeTicker.C:
			gc.mu.Lock()
			// Clear-all on tick, not per-entry TTL (see function doc).
			// A fresh map is allocated so the old buckets can be collected.
			gc.data = make(map[string]interface{})
			gc.mu.Unlock()
		case <-gc.stopChan:
			return
		}
	}
}
// Stop ends the background cleanup routine by closing stopChan.
//
// It is safe to call more than once and from several goroutines: closing an
// already-closed channel panics, so the close is guarded by a check made
// under the cache mutex. Cached entries are left in place.
func (gc *GenericCache) Stop() {
	gc.mu.Lock()
	defer gc.mu.Unlock()

	select {
	case <-gc.stopChan:
		// Already stopped; nothing to do.
	default:
		close(gc.stopChan)
	}
}