Files
traefikoidc/singleton_resources_test.go
T
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests time out or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00

932 lines
24 KiB
Go

package traefikoidc
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestSingletonResourceManager tests the singleton resource manager
// implementation: singleton identity, thread-safe lazy initialization, keyed
// sharing of HTTP clients and caches, background-task registration and
// lifecycle, reference counting, and graceful shutdown.
//
// Timing-dependent subtests poll with a deadline instead of sleeping for a
// fixed interval, so they stay reliable under -race and CPU contention while
// still finishing quickly on an idle machine.
func TestSingletonResourceManager(t *testing.T) {
	// waitUntil polls cond every 10ms until it returns true or timeout
	// elapses. Callers re-check their conditions afterwards and report any
	// failure themselves; waitUntil only bounds how long they wait.
	waitUntil := func(timeout time.Duration, cond func() bool) {
		deadline := time.Now().Add(timeout)
		for !cond() && time.Now().Before(deadline) {
			time.Sleep(10 * time.Millisecond)
		}
	}

	t.Run("SingletonInstance", func(t *testing.T) {
		// GetResourceManager must hand back the same instance on every call.
		rm1 := GetResourceManager()
		rm2 := GetResourceManager()
		if rm1 != rm2 {
			t.Error("GetResourceManager did not return singleton instance")
		}
	})

	t.Run("ThreadSafeInitialization", func(t *testing.T) {
		// Reset the singleton so the goroutines below race on first-time
		// initialization.
		resetResourceManagerForTesting()
		const numGoroutines = 100
		instances := make([]*ResourceManager, numGoroutines)
		var wg sync.WaitGroup
		for i := 0; i < numGoroutines; i++ {
			wg.Add(1)
			go func(idx int) {
				defer wg.Done()
				instances[idx] = GetResourceManager()
			}(i)
		}
		wg.Wait()
		// Every goroutine must have observed the same instance.
		first := instances[0]
		for i := 1; i < numGoroutines; i++ {
			if instances[i] != first {
				t.Errorf("Instance %d differs from first instance", i)
			}
		}
	})

	t.Run("SharedHTTPClient", func(t *testing.T) {
		rm := GetResourceManager()
		client1 := rm.GetHTTPClient("test-client-1")
		client2 := rm.GetHTTPClient("test-client-1")
		if client1 != client2 {
			t.Error("GetHTTPClient did not return same client for same key")
		}
		client3 := rm.GetHTTPClient("test-client-2")
		if client1 == client3 {
			t.Error("GetHTTPClient returned same client for different keys")
		}
	})

	t.Run("SharedCache", func(t *testing.T) {
		rm := GetResourceManager()
		cache1 := rm.GetCache("test-cache-1")
		cache2 := rm.GetCache("test-cache-1")
		if cache1 != cache2 {
			t.Error("GetCache did not return same cache for same key")
		}
	})

	t.Run("SingletonTaskRegistry", func(t *testing.T) {
		rm := GetResourceManager()
		err := rm.RegisterBackgroundTask("test-task", 1*time.Second, func() {
			// Test task
		})
		if err != nil {
			t.Errorf("Failed to register task: %v", err)
		}
		// Try to register same task again - should return existing rather than fail.
		err = rm.RegisterBackgroundTask("test-task", 1*time.Second, func() {
			// Duplicate task
		})
		if err != nil {
			t.Errorf("Failed to handle duplicate task registration: %v", err)
		}
	})

	t.Run("MultiRealmMetadataRefreshTaskNaming", func(t *testing.T) {
		// This test verifies that different provider URLs generate different
		// task names, which is critical for multi-realm Keycloak support (PR #88).
		// Reset singletons for clean test state.
		resetResourceManagerForTesting()
		ResetGlobalTaskRegistry()
		defer ResetGlobalTaskRegistry()
		rm := GetResourceManager()

		// Simulate different Keycloak realms.
		providerURL1 := "https://keycloak.example.com/realms/realm1"
		providerURL2 := "https://keycloak.example.com/realms/realm2"

		// Generate task names using the same logic as startMetadataRefresh.
		hash1 := sha256.Sum256([]byte(providerURL1))
		taskName1 := "singleton-metadata-refresh-" + hex.EncodeToString(hash1[:])[0:6]
		hash2 := sha256.Sum256([]byte(providerURL2))
		taskName2 := "singleton-metadata-refresh-" + hex.EncodeToString(hash2[:])[0:6]
		if taskName1 == taskName2 {
			t.Errorf("Task names should be different for different provider URLs: %s vs %s", taskName1, taskName2)
		}

		// Register both tasks.
		task1Called := int32(0)
		task2Called := int32(0)
		err := rm.RegisterBackgroundTask(taskName1, 100*time.Millisecond, func() {
			atomic.AddInt32(&task1Called, 1)
		})
		if err != nil {
			t.Errorf("Failed to register task 1: %v", err)
		}
		err = rm.RegisterBackgroundTask(taskName2, 100*time.Millisecond, func() {
			atomic.AddInt32(&task2Called, 1)
		})
		if err != nil {
			t.Errorf("Failed to register task 2: %v", err)
		}

		// Start both tasks, then wait (bounded) until each has fired at least once.
		_ = rm.StartBackgroundTask(taskName1)
		_ = rm.StartBackgroundTask(taskName2)
		waitUntil(5*time.Second, func() bool {
			return atomic.LoadInt32(&task1Called) > 0 && atomic.LoadInt32(&task2Called) > 0
		})

		// Verify both tasks are running independently.
		if !rm.IsTaskRunning(taskName1) {
			t.Error("Task 1 should be running")
		}
		if !rm.IsTaskRunning(taskName2) {
			t.Error("Task 2 should be running")
		}
		// Verify both tasks were called (at least once).
		if atomic.LoadInt32(&task1Called) == 0 {
			t.Error("Task 1 should have been called at least once")
		}
		if atomic.LoadInt32(&task2Called) == 0 {
			t.Error("Task 2 should have been called at least once")
		}

		// Stop both tasks, then wait (bounded) for them to report stopped.
		_ = rm.StopBackgroundTask(taskName1)
		_ = rm.StopBackgroundTask(taskName2)
		waitUntil(2*time.Second, func() bool {
			return !rm.IsTaskRunning(taskName1) && !rm.IsTaskRunning(taskName2)
		})
		if rm.IsTaskRunning(taskName1) {
			t.Error("Task 1 should be stopped")
		}
		if rm.IsTaskRunning(taskName2) {
			t.Error("Task 2 should be stopped")
		}
		t.Logf("Successfully verified multi-realm task isolation: task1=%s, task2=%s", taskName1, taskName2)
	})

	t.Run("ReferenceCountingCleanup", func(t *testing.T) {
		rm := GetResourceManager()
		// Add reference
		rm.AddReference("test-instance-1")
		if rm.GetReferenceCount("test-instance-1") != 1 {
			t.Error("Reference count should be 1")
		}
		// Add another reference
		rm.AddReference("test-instance-1")
		if rm.GetReferenceCount("test-instance-1") != 2 {
			t.Error("Reference count should be 2")
		}
		// Remove reference
		rm.RemoveReference("test-instance-1")
		if rm.GetReferenceCount("test-instance-1") != 1 {
			t.Error("Reference count should be 1 after removal")
		}
		// Remove last reference
		rm.RemoveReference("test-instance-1")
		if rm.GetReferenceCount("test-instance-1") != 0 {
			t.Error("Reference count should be 0 after removing all references")
		}
	})

	t.Run("GracefulShutdown", func(t *testing.T) {
		rm := GetResourceManager()
		// Shutdown below stops the process-wide singleton; replace it with a
		// fresh instance on exit so later tests do not inherit a shut-down
		// manager from GetResourceManager.
		defer resetResourceManagerForTesting()

		// Register a task with atomic variable to avoid race condition.
		var taskExecuted int32
		err := rm.RegisterBackgroundTask("shutdown-test-task", 100*time.Millisecond, func() {
			atomic.StoreInt32(&taskExecuted, 1)
		})
		if err != nil {
			t.Errorf("Failed to register task: %v", err)
		}
		rm.StartBackgroundTask("shutdown-test-task")

		// Wait (bounded) for the task to execute at least once.
		waitUntil(5*time.Second, func() bool {
			return atomic.LoadInt32(&taskExecuted) == 1
		})
		if atomic.LoadInt32(&taskExecuted) == 0 {
			t.Error("Task was not executed")
		}

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		err = rm.Shutdown(ctx)
		if err != nil {
			t.Errorf("Shutdown failed: %v", err)
		}
		// Verify task is stopped.
		if rm.IsTaskRunning("shutdown-test-task") {
			t.Error("Task should be stopped after shutdown")
		}
	})
}
// createMockOIDCServer starts an httptest.Server that stands in for an OIDC
// provider in tests. It serves two paths with fixed JSON bodies:
//
//   - /.well-known/openid-configuration: a discovery document
//   - /jwks: an empty key set ({"keys":[]})
//
// Any other path gets a 404 with an empty body. The caller owns the returned
// server and must Close it.
//
// NOTE(review): the discovery document advertises https://example.com
// endpoints rather than this server's own URL, so a client that follows
// jwks_uri (or any other advertised endpoint) leaves the mock. Confirm that
// is intended.
func createMockOIDCServer() *httptest.Server {
	// Response bodies keyed by request path. The handler only reads this map,
	// so concurrent requests can share it safely.
	payloads := map[string]map[string]interface{}{
		"/.well-known/openid-configuration": {
			"issuer":                 "https://example.com",
			"authorization_endpoint": "https://example.com/authorize",
			"token_endpoint":         "https://example.com/token",
			"jwks_uri":               "https://example.com/jwks",
			"userinfo_endpoint":      "https://example.com/userinfo",
		},
		"/jwks": {
			"keys": []interface{}{},
		},
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		payload, known := payloads[r.URL.Path]
		if !known {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(payload)
	}
	return httptest.NewServer(http.HandlerFunc(serve))
}
// TestContextAwareGoroutineManagement tests context-aware goroutine
// management. It checks three things:
//
//   - Plugin instances created with NewWithContext release their goroutines
//     once their context is cancelled and Close is called.
//   - Several such instances together do not leak goroutines.
//   - A task registered on the shared ResourceManager keeps running as one
//     task while several plugin instances exist.
//
// Each instance talks to a local mock OIDC server instead of a real provider.
// Every subtest tears down what it created even when it fails partway
// through, so a failure does not leak goroutines into later tests.
func TestContextAwareGoroutineManagement(t *testing.T) {
	t.Run("GoroutineCleanupOnContextCancel", func(t *testing.T) {
		// Reset singletons to ensure clean state.
		resetResourceManagerForTesting()
		ResetUniversalCacheManagerForTesting()
		defer ResetUniversalCacheManagerForTesting()

		mockServer := createMockOIDCServer()
		defer mockServer.Close()

		initialGoroutines := runtime.NumGoroutine()
		ctx, cancel := context.WithCancel(context.Background())
		// Release the context even if plugin creation fails; the explicit
		// cancel() below makes this deferred call a no-op on the happy path.
		defer cancel()

		config := &Config{
			ProviderURL:  mockServer.URL,
			ClientID:     "test-client",
			ClientSecret: "test-secret",
		}
		plugin, err := NewWithContext(ctx, config, nil, "test")
		if err != nil {
			t.Fatalf("Failed to create plugin: %v", err)
		}

		// Wait for goroutines to start.
		time.Sleep(100 * time.Millisecond)
		midGoroutines := runtime.NumGoroutine()
		if midGoroutines <= initialGoroutines {
			t.Error("No goroutines were created")
		}

		// Cancel the context, then close the plugin to trigger cleanup.
		cancel()
		plugin.Close()

		// Wait for cleanup.
		time.Sleep(500 * time.Millisecond)
		finalGoroutines := runtime.NumGoroutine()
		// Allow for some singleton background goroutines (caches, pools, etc.)
		// that are shared across all instances and persist for the test duration.
		tolerance := 10
		if finalGoroutines > initialGoroutines+tolerance {
			t.Errorf("Goroutine leak detected: initial=%d, final=%d", initialGoroutines, finalGoroutines)
		}
	})

	t.Run("NoGoroutineLeakOnMultipleInstances", func(t *testing.T) {
		// Reset singletons to ensure clean state.
		resetResourceManagerForTesting()
		ResetUniversalCacheManagerForTesting()
		defer ResetUniversalCacheManagerForTesting()

		mockServer1 := createMockOIDCServer()
		defer mockServer1.Close()
		mockServer2 := createMockOIDCServer()
		defer mockServer2.Close()
		mockServer3 := createMockOIDCServer()
		defer mockServer3.Close()

		initialGoroutines := runtime.NumGoroutine()
		configs := []Config{
			{ProviderURL: mockServer1.URL, ClientID: "client1", ClientSecret: "secret1"},
			{ProviderURL: mockServer2.URL, ClientID: "client2", ClientSecret: "secret2"},
			{ProviderURL: mockServer3.URL, ClientID: "client3", ClientSecret: "secret3"},
		}
		var plugins []*TraefikOidc
		var cancels []context.CancelFunc

		for i := range configs {
			ctx, cancel := context.WithCancel(context.Background())
			cancels = append(cancels, cancel)
			// Point at the slice element instead of a range variable: before
			// Go 1.22 the range variable is shared across iterations, so every
			// instance would receive a pointer to the same (last) Config.
			plugin, err := NewWithContext(ctx, &configs[i], nil, fmt.Sprintf("test-%d", i))
			if err != nil {
				// Tear down what was already created before aborting.
				for _, c := range cancels {
					c()
				}
				for _, p := range plugins {
					p.Close()
				}
				t.Fatalf("Failed to create plugin %d: %v", i, err)
			}
			plugins = append(plugins, plugin)
		}

		// Wait for all goroutines to start.
		time.Sleep(200 * time.Millisecond)
		midGoroutines := runtime.NumGoroutine()

		for _, cancel := range cancels {
			cancel()
		}
		for _, plugin := range plugins {
			plugin.Close()
		}

		// Wait for cleanup, then check for leaks.
		time.Sleep(500 * time.Millisecond)
		finalGoroutines := runtime.NumGoroutine()
		tolerance := 5
		if finalGoroutines > initialGoroutines+tolerance {
			t.Errorf("Goroutine leak with multiple instances: initial=%d, mid=%d, final=%d",
				initialGoroutines, midGoroutines, finalGoroutines)
		}
	})

	t.Run("SingletonTasksAcrossInstances", func(t *testing.T) {
		// Reset singletons to ensure clean state.
		ResetGlobalTaskRegistry() // Reset circuit breaker and task registry
		resetResourceManagerForTesting()
		ResetUniversalCacheManagerForTesting()
		defer ResetUniversalCacheManagerForTesting()

		mockServers := make([]*httptest.Server, 3)
		for i := 0; i < 3; i++ {
			mockServers[i] = createMockOIDCServer()
			// Deliberately deferred in the loop: all servers stay up until
			// the subtest returns.
			defer mockServers[i].Close()
		}

		rm := GetResourceManager()
		var cleanupCount int32
		err := rm.RegisterBackgroundTask("singleton-cleanup", 100*time.Millisecond, func() {
			atomic.AddInt32(&cleanupCount, 1)
		})
		if err != nil {
			t.Fatalf("Failed to register singleton task: %v", err)
		}
		rm.StartBackgroundTask("singleton-cleanup")
		// Stop the task on every exit path, including a Fatalf below. This
		// runs after the plugins are closed (defers run LIFO).
		defer rm.StopBackgroundTask("singleton-cleanup")

		var plugins []*TraefikOidc
		// Close every plugin created so far, even if creation fails midway.
		defer func() {
			for _, plugin := range plugins {
				plugin.Close()
			}
		}()
		for i := 0; i < 3; i++ {
			ctx := context.Background()
			config := &Config{
				ProviderURL:  mockServers[i].URL,
				ClientID:     fmt.Sprintf("client%d", i),
				ClientSecret: fmt.Sprintf("secret%d", i),
			}
			plugin, err := NewWithContext(ctx, config, nil, fmt.Sprintf("test-%d", i))
			if err != nil {
				t.Fatalf("Failed to create plugin %d: %v", i, err)
			}
			plugins = append(plugins, plugin)
		}

		// Wait for the cleanup task to run at least minExpectedCount times.
		// The timeout is generous because the race detector and full-suite
		// CPU contention can slow goroutine scheduling significantly.
		const minExpectedCount = 2
		const maxExpectedCount = 5
		timeout := time.After(5 * time.Second)
		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()
		var count int32
	waitLoop:
		for {
			select {
			case <-ticker.C:
				count = atomic.LoadInt32(&cleanupCount)
				if count >= minExpectedCount {
					break waitLoop
				}
			case <-timeout:
				count = atomic.LoadInt32(&cleanupCount)
				t.Errorf("Timeout waiting for cleanup count to reach %d, got %d (race detector may be slowing execution)", minExpectedCount, count)
				break waitLoop
			}
		}

		// A count above the maximum means the task ran more often than a
		// single shared 100ms task would.
		if count > maxExpectedCount {
			t.Errorf("Cleanup count too high: %d (expected max %d for singleton)", count, maxExpectedCount)
		}
	})
}
// TestResourcePooling tests the goroutine pools and HTTP clients handed out by
// the shared ResourceManager. It checks three things:
//
//   - A pool never runs more tasks at once than its configured worker count.
//   - A pool that has been shut down rejects new submissions.
//   - Asking for the same key twice returns the same pool or client.
//
// Counters updated by tasks with sync/atomic are also read with sync/atomic,
// and each is loaded once so the value checked is the value reported.
func TestResourcePooling(t *testing.T) {
	t.Run("GoroutinePoolLimiting", func(t *testing.T) {
		rm := GetResourceManager()
		pool := rm.GetGoroutinePool("test-pool", 5) // Max 5 workers
		if pool == nil {
			t.Fatal("Failed to get goroutine pool")
		}

		// Submit more tasks than the pool size, tracking peak concurrency.
		var taskCount int32
		var runningCount int32
		maxRunning := int32(0)
		for i := 0; i < 20; i++ {
			err := pool.Submit(func() {
				atomic.AddInt32(&taskCount, 1)
				current := atomic.AddInt32(&runningCount, 1)
				// Lock-free "max" update: retry until maxRunning is at least current.
				for {
					oldMax := atomic.LoadInt32(&maxRunning)
					if current <= oldMax || atomic.CompareAndSwapInt32(&maxRunning, oldMax, current) {
						break
					}
				}
				time.Sleep(50 * time.Millisecond)
				atomic.AddInt32(&runningCount, -1)
			})
			if err != nil {
				t.Errorf("Failed to submit task %d: %v", i, err)
			}
		}

		pool.Wait()

		if executed := atomic.LoadInt32(&taskCount); executed != 20 {
			t.Errorf("Expected 20 tasks to execute, got %d", executed)
		}
		if peak := atomic.LoadInt32(&maxRunning); peak > 5 {
			t.Errorf("Max concurrent tasks exceeded pool size: %d > 5", peak)
		}
	})

	t.Run("PoolShutdown", func(t *testing.T) {
		rm := GetResourceManager()
		pool := rm.GetGoroutinePool("shutdown-pool", 3)

		// Submit tasks
		var completed int32
		for i := 0; i < 10; i++ {
			pool.Submit(func() {
				time.Sleep(10 * time.Millisecond)
				atomic.AddInt32(&completed, 1)
			})
		}

		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		err := pool.Shutdown(ctx)
		if err != nil {
			t.Errorf("Pool shutdown failed: %v", err)
		}

		// Try to submit after shutdown - should fail.
		err = pool.Submit(func() {
			t.Error("Task should not execute after shutdown")
		})
		if err == nil {
			t.Error("Expected error when submitting to shutdown pool")
		}
	})

	t.Run("ResourceReuse", func(t *testing.T) {
		rm := GetResourceManager()

		pool1 := rm.GetGoroutinePool("reuse-pool", 3)
		pool2 := rm.GetGoroutinePool("reuse-pool", 3)
		if pool1 != pool2 {
			t.Error("Expected same pool instance for same key")
		}

		client1 := rm.GetHTTPClient("reuse-client")
		client2 := rm.GetHTTPClient("reuse-client")
		if client1 != client2 {
			t.Error("Expected same HTTP client instance for same key")
		}
	})
}
// TestBackwardCompatibility verifies the changes maintain backward compatibility
func TestBackwardCompatibility(t *testing.T) {
t.Run("LegacyNewFunction", func(t *testing.T) {
// Test that the original New function still works
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
}
handler, err := New(context.Background(), nil, config, "test")
if err != nil {
t.Fatalf("Legacy New function failed: %v", err)
}
if handler == nil {
t.Fatal("Handler should not be nil")
}
// Cleanup - cast to TraefikOidc if needed
if plugin, ok := handler.(*TraefikOidc); ok {
plugin.Close()
}
})
t.Run("ExistingAPICompatibility", func(t *testing.T) {
config := &Config{
ProviderURL: "https://example.com",
ClientID: "test-client",
ClientSecret: "test-secret",
}
handler, _ := New(context.Background(), nil, config, "test")
// Test that the handler works
if handler == nil {
t.Error("Handler should not be nil")
}
// Cleanup - cast to TraefikOidc if needed
if plugin, ok := handler.(*TraefikOidc); ok {
plugin.Close()
}
})
}
// TestGoroutinePoolConditionVariable tests the condition variable-based Wait
// implementation of GoroutinePool. It covers:
//
//   - Wait and WaitWithTimeout blocking until in-flight work finishes.
//   - Returning immediately when the pool is idle.
//   - Concurrent submitters and multiple concurrent waiters.
//   - PendingTasks staying correct across failed submits and panicking tasks.
//
// Every Submit result is checked, so a rejected task fails the test
// immediately. Otherwise it would show up later as a wrong count or, in
// PendingTasksCounter, as a test that hangs forever.
func TestGoroutinePoolConditionVariable(t *testing.T) {
	t.Run("WaitDoesNotBusyPoll", func(t *testing.T) {
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		var taskStarted, taskFinished int32
		submitted := time.Now()
		if err := pool.Submit(func() {
			atomic.StoreInt32(&taskStarted, 1)
			time.Sleep(100 * time.Millisecond)
			atomic.StoreInt32(&taskFinished, 1)
		}); err != nil {
			t.Fatalf("Failed to submit task: %v", err)
		}

		// Give the task time to start so Wait is entered while it runs.
		time.Sleep(10 * time.Millisecond)
		pool.Wait()
		// Measure from submission, not from the call to Wait. The task alone
		// sleeps 100ms after being submitted, so a correct Wait always sees at
		// least that much; a slow scheduler cannot make this check flaky.
		elapsed := time.Since(submitted)

		if atomic.LoadInt32(&taskStarted) != 1 {
			t.Error("Task should have started")
		}
		if atomic.LoadInt32(&taskFinished) != 1 {
			t.Error("Task should have finished")
		}
		// Sanity check only; whether Wait busy-polls is really only
		// observable through profiling.
		if elapsed < 100*time.Millisecond {
			t.Errorf("Wait returned too quickly: %v", elapsed)
		}
	})

	t.Run("WaitReturnsImmediatelyWhenEmpty", func(t *testing.T) {
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		start := time.Now()
		pool.Wait()
		elapsed := time.Since(start)
		if elapsed > 10*time.Millisecond {
			t.Errorf("Wait on empty pool took too long: %v", elapsed)
		}
	})

	t.Run("ConcurrentSubmitAndWait", func(t *testing.T) {
		pool := NewGoroutinePool(4, nil)
		defer pool.Shutdown(context.Background())

		var completed int32
		const numTasks = 100
		var wg sync.WaitGroup
		for i := 0; i < numTasks; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				// t.Errorf (unlike t.Fatalf) may be called from other goroutines.
				if err := pool.Submit(func() {
					time.Sleep(1 * time.Millisecond)
					atomic.AddInt32(&completed, 1)
				}); err != nil {
					t.Errorf("Failed to submit task: %v", err)
				}
			}()
		}
		wg.Wait() // all submissions have returned

		pool.Wait()
		if done := atomic.LoadInt32(&completed); done != numTasks {
			t.Errorf("Expected %d tasks completed, got %d", numTasks, done)
		}
	})

	t.Run("WaitWithTimeoutSuccess", func(t *testing.T) {
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		if err := pool.Submit(func() {
			time.Sleep(50 * time.Millisecond)
		}); err != nil {
			t.Fatalf("Failed to submit task: %v", err)
		}
		if !pool.WaitWithTimeout(1 * time.Second) {
			t.Error("WaitWithTimeout should have succeeded")
		}
	})

	t.Run("WaitWithTimeoutExpired", func(t *testing.T) {
		pool := NewGoroutinePool(1, nil)
		defer pool.Shutdown(context.Background())

		if err := pool.Submit(func() {
			time.Sleep(500 * time.Millisecond)
		}); err != nil {
			t.Fatalf("Failed to submit task: %v", err)
		}
		if pool.WaitWithTimeout(50 * time.Millisecond) {
			t.Error("WaitWithTimeout should have timed out")
		}
		// Wait for actual completion to avoid goroutine leak in test.
		pool.Wait()
	})

	t.Run("PendingTasksCounter", func(t *testing.T) {
		// NOTE(review): assumes a 2-worker pool can queue the two extra tasks
		// below without Submit blocking (original note: "maxWorkers=2,
		// buffer=4"). Confirm against NewGoroutinePool.
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		if pool.PendingTasks() != 0 {
			t.Errorf("Expected 0 pending tasks, got %d", pool.PendingTasks())
		}

		// Block both workers, with signals that each blocking task has started.
		blocker1 := make(chan struct{})
		blocker2 := make(chan struct{})
		started1 := make(chan struct{})
		started2 := make(chan struct{})
		// Release the blockers exactly once: explicitly below, or by this
		// defer if a Fatalf aborts early. Defers run LIFO, so this runs
		// before the deferred Shutdown and keeps it from waiting on
		// permanently blocked workers.
		var releaseOnce sync.Once
		release := func() {
			releaseOnce.Do(func() {
				close(blocker1)
				close(blocker2)
			})
		}
		defer release()

		if err := pool.Submit(func() {
			close(started1)
			<-blocker1
		}); err != nil {
			t.Fatalf("Failed to submit blocking task 1: %v", err)
		}
		if err := pool.Submit(func() {
			close(started2)
			<-blocker2
		}); err != nil {
			t.Fatalf("Failed to submit blocking task 2: %v", err)
		}
		<-started1
		<-started2

		// Submit 2 more tasks that will queue up behind the blocked workers.
		for i := 0; i < 2; i++ {
			if err := pool.Submit(func() {
				time.Sleep(1 * time.Millisecond)
			}); err != nil {
				t.Fatalf("Failed to submit queued task %d: %v", i, err)
			}
		}

		// 2 running + 2 queued = 4 pending.
		if pending := pool.PendingTasks(); pending != 4 {
			t.Errorf("Expected 4 pending tasks, got %d", pending)
		}

		release()
		pool.Wait()
		if pool.PendingTasks() != 0 {
			t.Errorf("Expected 0 pending tasks after Wait, got %d", pool.PendingTasks())
		}
	})

	t.Run("MultipleWaiters", func(t *testing.T) {
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		if err := pool.Submit(func() {
			time.Sleep(100 * time.Millisecond)
		}); err != nil {
			t.Fatalf("Failed to submit task: %v", err)
		}

		var waiters sync.WaitGroup
		var waitCount int32
		for i := 0; i < 5; i++ {
			waiters.Add(1)
			go func() {
				defer waiters.Done()
				pool.Wait()
				atomic.AddInt32(&waitCount, 1)
			}()
		}
		waiters.Wait()
		if n := atomic.LoadInt32(&waitCount); n != 5 {
			t.Errorf("Expected all 5 waiters to complete, got %d", n)
		}
	})

	t.Run("SubmitFailureDoesNotIncrementPending", func(t *testing.T) {
		pool := NewGoroutinePool(1, nil)
		pool.Shutdown(context.Background())

		err := pool.Submit(func() {})
		if err == nil {
			t.Error("Submit should fail on shutdown pool")
		}
		if pool.PendingTasks() != 0 {
			t.Errorf("Pending tasks should be 0 after failed submit, got %d", pool.PendingTasks())
		}
	})

	t.Run("PanicRecoveryDecrementsPending", func(t *testing.T) {
		pool := NewGoroutinePool(2, nil)
		defer pool.Shutdown(context.Background())

		if err := pool.Submit(func() {
			panic("test panic")
		}); err != nil {
			t.Fatalf("Failed to submit panicking task: %v", err)
		}
		var normalCompleted int32
		if err := pool.Submit(func() {
			atomic.StoreInt32(&normalCompleted, 1)
		}); err != nil {
			t.Fatalf("Failed to submit normal task: %v", err)
		}

		// Wait must still return: the pool is expected to recover the panic.
		pool.Wait()
		if atomic.LoadInt32(&normalCompleted) != 1 {
			t.Error("Normal task should have completed despite panic in other task")
		}
		if pool.PendingTasks() != 0 {
			t.Errorf("Pending tasks should be 0 after Wait, got %d", pool.PendingTasks())
		}
	})
}
// BenchmarkGoroutinePoolWait benchmarks the Wait implementation: each
// iteration submits one no-op task to a 4-worker pool and waits for the pool
// to drain.
func BenchmarkGoroutinePoolWait(b *testing.B) {
	pool := NewGoroutinePool(4, nil)
	defer pool.Shutdown(context.Background())
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		// If Submit were rejected, Wait would have nothing to wait for and
		// the timing would be meaningless, so fail rather than ignore it.
		if err := pool.Submit(func() {}); err != nil {
			b.Fatalf("Submit failed: %v", err)
		}
		pool.Wait()
	}
}
// BenchmarkGoroutinePoolHighThroughput benchmarks a high-throughput scenario:
// each iteration submits 100 trivial tasks to an 8-worker pool and then waits
// for all of them to finish.
func BenchmarkGoroutinePoolHighThroughput(b *testing.B) {
	pool := NewGoroutinePool(8, nil)
	defer pool.Shutdown(context.Background())
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for j := 0; j < 100; j++ {
			// Rejected submissions would shrink the batch and silently skew
			// the measurement, so treat them as a benchmark failure.
			if err := pool.Submit(func() {
				// Minimal work
				_ = 1 + 1
			}); err != nil {
				b.Fatalf("Submit failed: %v", err)
			}
		}
		pool.Wait()
	}
}
// resetResourceManagerForTesting discards the process-wide ResourceManager so
// that the next GetResourceManager call builds a fresh one.
//
// While holding resourceManagerMutex it:
//
//   - shuts down any existing instance (bounded by a 1-second timeout),
//   - replaces resourceManagerOnce with a fresh sync.Once, and
//   - clears globalResourceManager.
//
// The Shutdown error is deliberately ignored: this is best-effort cleanup
// between tests.
func resetResourceManagerForTesting() {
	resourceManagerMutex.Lock()
	defer resourceManagerMutex.Unlock()

	if current := globalResourceManager; current != nil {
		shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancelShutdown()
		_ = current.Shutdown(shutdownCtx)
	}

	resourceManagerOnce = sync.Once{}
	globalResourceManager = nil
}