mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 413e4a1b7d | |||
| 69e0d98c67 | |||
| 6d893df12b | |||
| 6efb78b7a8 | |||
| d0b920c4f0 |
@@ -11,7 +11,9 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
|
||||
@@ -47,3 +47,14 @@ release:
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sigstore.json"
|
||||
args:
|
||||
- sign-blob
|
||||
- "--bundle=${signature}"
|
||||
- "${artifact}"
|
||||
- "--yes"
|
||||
artifacts: checksum
|
||||
output: true
|
||||
|
||||
@@ -82,6 +82,19 @@ experimental:
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Verifying Release Signatures
|
||||
|
||||
All release checksums are signed with [cosign](https://github.com/sigstore/cosign) using keyless signing. To verify:
|
||||
|
||||
```bash
|
||||
# Download the checksum file and its sigstore bundle from the release
|
||||
cosign verify-blob \
|
||||
--certificate-identity-regexp "https://github.com/lukaszraczylo/traefikoidc/.*" \
|
||||
--certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
|
||||
--bundle "traefikoidc_v<version>_checksums.txt.sigstore.json" \
|
||||
traefikoidc_v<version>_checksums.txt
|
||||
```
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
|
||||
+4
-4
@@ -84,8 +84,8 @@ func TestAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
expectError bool
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid custom audience URL",
|
||||
@@ -163,8 +163,8 @@ func TestConfigAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantErr bool
|
||||
errContains string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty audience is valid for backward compatibility",
|
||||
@@ -732,11 +732,11 @@ func TestJWTAudienceVerification(t *testing.T) {
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tests := []struct {
|
||||
tokenAudience interface{}
|
||||
name string
|
||||
configAudience string
|
||||
tokenAudience interface{}
|
||||
wantErr bool
|
||||
errContains string
|
||||
wantErr bool
|
||||
skipReplayCheck bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -253,8 +253,8 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication_WithPKCE()
|
||||
// TestIsAjaxRequest tests AJAX request detection
|
||||
func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
name string
|
||||
expectAjax bool
|
||||
}{
|
||||
{
|
||||
|
||||
+17
-15
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
|
||||
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
|
||||
// It limits concurrent task execution and tracks failures to prevent system overload
|
||||
type TaskCircuitBreaker struct {
|
||||
state int32 // CircuitBreakerState
|
||||
failureCount int32
|
||||
lastFailureTime int64 // Unix timestamp
|
||||
failureThreshold int32
|
||||
timeout time.Duration
|
||||
logger *Logger
|
||||
// Concurrency limiting
|
||||
concurrentTasks int32 // Current number of running tasks
|
||||
maxConcurrent int32 // Maximum concurrent tasks allowed
|
||||
activeTasks map[string]struct{} // Track active task names
|
||||
tasksMu sync.RWMutex // Separate mutex for task tracking
|
||||
activeTasks map[string]struct{}
|
||||
lastFailureTime int64
|
||||
timeout time.Duration
|
||||
tasksMu sync.RWMutex
|
||||
state int32
|
||||
failureCount int32
|
||||
failureThreshold int32
|
||||
concurrentTasks int32
|
||||
maxConcurrent int32
|
||||
}
|
||||
|
||||
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
|
||||
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
|
||||
max := atomic.LoadInt32(&cb.maxConcurrent)
|
||||
|
||||
// For cleanup tasks, be more restrictive (singleton-like behavior)
|
||||
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
|
||||
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
|
||||
cb.tasksMu.RLock()
|
||||
hasCleanupTask := false
|
||||
hasSameTask := false
|
||||
for activeTask := range cb.activeTasks {
|
||||
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
|
||||
hasCleanupTask = true
|
||||
// Only block if the EXACT same task is already running
|
||||
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
|
||||
if activeTask == taskName {
|
||||
hasSameTask = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cb.tasksMu.RUnlock()
|
||||
|
||||
if hasCleanupTask {
|
||||
if hasSameTask {
|
||||
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
|
||||
}
|
||||
}
|
||||
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
|
||||
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
cb *TaskCircuitBreaker
|
||||
logger *Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// GlobalTaskRegistry is the singleton instance for managing all background tasks
|
||||
|
||||
+6
-6
@@ -330,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidGoogleTokens",
|
||||
@@ -476,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
providerType string
|
||||
setupSession func() *SessionData
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "AzureProvider",
|
||||
@@ -660,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UnauthenticatedWithRefreshToken",
|
||||
|
||||
@@ -97,15 +97,15 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
|
||||
t.Run("String method returns pressure name", func(t *testing.T) {
|
||||
pressures := []struct {
|
||||
level MemoryPressureLevel
|
||||
name string
|
||||
level MemoryPressureLevel
|
||||
}{
|
||||
{MemoryPressureNone, "None"},
|
||||
{MemoryPressureLow, "Low"},
|
||||
{MemoryPressureModerate, "Moderate"},
|
||||
{MemoryPressureHigh, "High"},
|
||||
{MemoryPressureCritical, "Critical"},
|
||||
{MemoryPressureLevel(999), "Unknown"},
|
||||
{level: MemoryPressureNone, name: "None"},
|
||||
{level: MemoryPressureLow, name: "Low"},
|
||||
{level: MemoryPressureModerate, name: "Moderate"},
|
||||
{level: MemoryPressureHigh, name: "High"},
|
||||
{level: MemoryPressureCritical, name: "Critical"},
|
||||
{level: MemoryPressureLevel(999), name: "Unknown"},
|
||||
}
|
||||
|
||||
for _, p := range pressures {
|
||||
|
||||
+3
-3
@@ -155,9 +155,9 @@ type CacheStrategy interface {
|
||||
|
||||
// CacheEntry for backward compatibility
|
||||
type CacheEntry struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
Value interface{}
|
||||
Key string
|
||||
}
|
||||
|
||||
// Cache is an alias for backward compatibility
|
||||
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
|
||||
|
||||
// ListNode for backward compatibility
|
||||
type ListNode struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
Next *ListNode
|
||||
Prev *ListNode
|
||||
Key string
|
||||
}
|
||||
|
||||
// NewFixedMetadataCache creates a metadata cache with fixed configuration
|
||||
|
||||
+13
-6
@@ -61,7 +61,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
|
||||
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenCache returns the shared token cache
|
||||
@@ -93,7 +93,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
@@ -101,7 +101,7 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
@@ -121,7 +121,8 @@ func CleanupGlobalCacheManager() error {
|
||||
|
||||
// CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface
|
||||
type CacheInterfaceWrapper struct {
|
||||
cache *UniversalCache
|
||||
cache *UniversalCache
|
||||
managed bool // If true, cache is managed globally and Close() is a no-op
|
||||
}
|
||||
|
||||
// Set stores a value
|
||||
@@ -149,9 +150,15 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
c.cache.Cleanup()
|
||||
}
|
||||
|
||||
// Close shuts down the cache
|
||||
// Close shuts down the cache if it's not managed globally.
|
||||
// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding
|
||||
// when multiple plugin instances are closed during Traefik configuration reloads.
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.managed {
|
||||
// Cache is managed globally by UniversalCacheManager, so we don't close it here.
|
||||
return
|
||||
}
|
||||
// Standalone cache - close it properly to stop cleanup goroutines
|
||||
if c.cache != nil {
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
|
||||
+164
-11
@@ -19,16 +19,16 @@ import (
|
||||
|
||||
// CacheTestCase represents a comprehensive test case for cache operations
|
||||
type CacheTestCase struct {
|
||||
setup func(*TestFramework)
|
||||
execute func(*TestFramework) error
|
||||
validate func(*testing.T, error, *TestFramework)
|
||||
cleanup func(*TestFramework)
|
||||
name string
|
||||
cacheType string // "universal", "metadata", "bounded"
|
||||
operation string // "get", "set", "evict", "cleanup"
|
||||
setup func(*TestFramework) // Pre-test setup
|
||||
execute func(*TestFramework) error // Test execution
|
||||
validate func(*testing.T, error, *TestFramework) // Validation logic
|
||||
cleanup func(*TestFramework) // Post-test cleanup
|
||||
timeout time.Duration // Test timeout
|
||||
parallel bool // Can run in parallel
|
||||
skipReason string // Optional reason to skip
|
||||
cacheType string
|
||||
operation string
|
||||
skipReason string
|
||||
timeout time.Duration
|
||||
parallel bool
|
||||
}
|
||||
|
||||
// createTestCacheConfig creates a standard test configuration
|
||||
@@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
nilWrapper.Close()
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers
|
||||
// don't close the underlying cache when Close() is called. This is a regression test
|
||||
// for issue #105 where multiple plugin instances closing shared caches caused log flooding.
|
||||
func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Get a managed cache wrapper
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's marked as managed
|
||||
if !wrapper.managed {
|
||||
t.Error("Expected shared cache wrapper to be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data before Close
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Close the wrapper (should be a no-op for managed caches)
|
||||
wrapper.Close()
|
||||
|
||||
// Verify the cache is still operational after Close
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected cache to still work after Close() on managed wrapper")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Can still set new values
|
||||
cache.Set("new-key", "new-value", time.Hour)
|
||||
newValue, found := cache.Get("new-key")
|
||||
if !found || newValue != "new-value" {
|
||||
t.Error("Expected to be able to set new values after Close() on managed wrapper")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers
|
||||
// properly close the underlying cache when Close() is called.
|
||||
func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) {
|
||||
// Create a standalone cache (not from the global cache manager)
|
||||
standaloneCache := NewCache()
|
||||
|
||||
wrapper, ok := standaloneCache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Verify it's NOT marked as managed
|
||||
if wrapper.managed {
|
||||
t.Error("Expected standalone cache wrapper to NOT be marked as managed")
|
||||
}
|
||||
|
||||
// Set some data
|
||||
standaloneCache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Get baseline goroutine count
|
||||
baselineGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// Close the wrapper (should actually close the underlying cache)
|
||||
wrapper.Close()
|
||||
|
||||
// Give cleanup goroutine time to stop
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Goroutine count should decrease (cleanup routine stopped)
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > baselineGoroutines {
|
||||
// This is acceptable - other tests might have started goroutines
|
||||
t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple
|
||||
// plugin instances can close their cache wrappers without affecting shared caches.
|
||||
// This is a regression test for issue #105.
|
||||
func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Simulate multiple plugin instances getting cache references
|
||||
instances := make([]*CacheInterfaceWrapper, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
instances[i] = wrapper
|
||||
|
||||
// Each instance might set some data
|
||||
cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour)
|
||||
}
|
||||
|
||||
// Close all instances (simulating plugin shutdown/reload)
|
||||
for _, wrapper := range instances {
|
||||
wrapper.Close()
|
||||
}
|
||||
|
||||
// The shared cache should still work after all instances closed their wrappers
|
||||
newCache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Data set by earlier instances should still be accessible
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("instance-%d-key", i)
|
||||
value, found := newCache.Get(key)
|
||||
if !found {
|
||||
t.Errorf("Expected data from instance %d to still be accessible", i)
|
||||
}
|
||||
expectedValue := fmt.Sprintf("value-%d", i)
|
||||
if value != expectedValue {
|
||||
t.Errorf("Expected '%s', got '%v'", expectedValue, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to add new data
|
||||
newCache.Set("after-close-key", "after-close-value", time.Hour)
|
||||
value, found := newCache.Get("after-close-key")
|
||||
if !found || value != "after-close-value" {
|
||||
t.Error("Expected to be able to use cache after all wrapper Close() calls")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters
|
||||
// return managed wrappers to prevent the log flooding issue.
|
||||
func TestAllSharedCachesMarkedAsManaged(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache CacheInterface
|
||||
}{
|
||||
{"TokenBlacklist", cm.GetSharedTokenBlacklist()},
|
||||
{"IntrospectionCache", cm.GetSharedIntrospectionCache()},
|
||||
{"TokenTypeCache", cm.GetSharedTokenTypeCache()},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wrapper, ok := tt.cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name)
|
||||
}
|
||||
if !wrapper.managed {
|
||||
t.Errorf("%s cache wrapper should be marked as managed", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
@@ -698,10 +851,10 @@ func TestUnifiedCache_SetMaxSize(t *testing.T) {
|
||||
|
||||
func TestNewCacheAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache interface{}
|
||||
expectNil bool
|
||||
name string
|
||||
description string
|
||||
expectNil bool
|
||||
}{
|
||||
{
|
||||
name: "UniversalCache",
|
||||
|
||||
@@ -353,6 +353,8 @@ allowPrivateIPAddresses: true # Required for private IPs
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
See [KEYCLOAK_SETUP_GUIDE.md](KEYCLOAK_SETUP_GUIDE.md) for detailed step-by-step setup instructions, mapper configuration, troubleshooting, and performance optimization.
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
@@ -16,35 +16,26 @@ import (
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
// Required fields
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// Conditional - only for confidential clients
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
|
||||
// Optional - for managing registration
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
|
||||
// Expiration
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
|
||||
// Echo back of registered metadata
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationError represents an error response from client registration (RFC 7591)
|
||||
@@ -55,14 +46,12 @@ type ClientRegistrationError struct {
|
||||
|
||||
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrar struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
providerURL string
|
||||
|
||||
// Cached registration response
|
||||
mu sync.RWMutex
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrar creates a new dynamic client registrar
|
||||
|
||||
@@ -223,10 +223,10 @@ func TestRegisterClientWithInitialAccessToken(t *testing.T) {
|
||||
// TestRegisterClientError tests error handling during registration
|
||||
func TestRegisterClientError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse func(w http.ResponseWriter, r *http.Request)
|
||||
expectError bool
|
||||
name string
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "invalid_redirect_uri error",
|
||||
@@ -321,8 +321,8 @@ func TestRegisterClientError(t *testing.T) {
|
||||
// TestRegisterClientDisabled tests that registration fails when not enabled
|
||||
func TestRegisterClientDisabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
@@ -521,8 +521,8 @@ func TestCredentialsValidation(t *testing.T) {
|
||||
registrar := NewDynamicClientRegistrar(&http.Client{}, NewLogger("DEBUG"), dcrConfig, "https://example.com")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response *ClientRegistrationResponse
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -584,9 +584,9 @@ func TestCredentialsValidation(t *testing.T) {
|
||||
// TestBuildRegistrationRequest tests the request body construction
|
||||
func TestBuildRegistrationRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata *ClientRegistrationMetadata
|
||||
expectedFields map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
|
||||
+29
-47
@@ -12,23 +12,19 @@ import (
|
||||
|
||||
// EnhancedMockJWKCache is an improved state-based mock with call tracking
|
||||
type EnhancedMockJWKCache struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return)
|
||||
JWKS *JWKSet
|
||||
Err error
|
||||
|
||||
// Call tracking
|
||||
Err error
|
||||
JWKS *JWKSet
|
||||
GetJWKSCalls []JWKSCall
|
||||
mu sync.RWMutex
|
||||
getJWKSCallsMu sync.Mutex
|
||||
CleanupCalls int32
|
||||
CloseCalls int32
|
||||
getJWKSCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// JWKSCall records parameters from a GetJWKS call
|
||||
type JWKSCall struct {
|
||||
URL string
|
||||
Timestamp time.Time
|
||||
URL string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
@@ -108,22 +104,18 @@ func (m *EnhancedMockJWKCache) Reset() {
|
||||
|
||||
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenVerifier struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return) - can be a fixed error or a function
|
||||
Err error
|
||||
VerifyFunc func(token string) error
|
||||
|
||||
// Call tracking
|
||||
Err error
|
||||
VerifyFunc func(token string) error
|
||||
VerifyCalls []TokenVerifyCall
|
||||
mu sync.RWMutex
|
||||
verifyCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// TokenVerifyCall records parameters from a VerifyToken call
|
||||
type TokenVerifyCall struct {
|
||||
Token string
|
||||
Timestamp time.Time
|
||||
Result error
|
||||
Token string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
|
||||
@@ -207,49 +199,43 @@ func (m *EnhancedMockTokenVerifier) Reset() {
|
||||
|
||||
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenExchanger struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// State (what to return)
|
||||
ExchangeResponse *TokenResponse
|
||||
ExchangeErr error
|
||||
RefreshResponse *TokenResponse
|
||||
RefreshErr error
|
||||
RevokeErr error
|
||||
|
||||
// Optional functions for dynamic behavior
|
||||
ExchangeErr error
|
||||
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshResponse *TokenResponse
|
||||
ExchangeResponse *TokenResponse
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
|
||||
// Call tracking
|
||||
ExchangeCalls []ExchangeCall
|
||||
RefreshCalls []RefreshCall
|
||||
RevokeCalls []RevokeCall
|
||||
exchangeCallsMu sync.Mutex
|
||||
refreshCallsMu sync.Mutex
|
||||
revokeCallsMu sync.Mutex
|
||||
ExchangeCalls []ExchangeCall
|
||||
RefreshCalls []RefreshCall
|
||||
RevokeCalls []RevokeCall
|
||||
mu sync.RWMutex
|
||||
exchangeCallsMu sync.Mutex
|
||||
refreshCallsMu sync.Mutex
|
||||
revokeCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// ExchangeCall records parameters from an ExchangeCodeForToken call
|
||||
type ExchangeCall struct {
|
||||
Timestamp time.Time
|
||||
GrantType string
|
||||
CodeOrToken string
|
||||
RedirectURL string
|
||||
CodeVerifier string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
|
||||
type RefreshCall struct {
|
||||
RefreshToken string
|
||||
Timestamp time.Time
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// RevokeCall records parameters from a RevokeTokenWithProvider call
|
||||
type RevokeCall struct {
|
||||
Timestamp time.Time
|
||||
Token string
|
||||
TokenType string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
@@ -401,16 +387,12 @@ func (m *EnhancedMockTokenExchanger) Reset() {
|
||||
|
||||
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
|
||||
type EnhancedMockCacheInterface struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Internal storage
|
||||
data map[string]cacheEntry
|
||||
maxSize int
|
||||
|
||||
// Call tracking
|
||||
data map[string]cacheEntry
|
||||
GetCalls []CacheGetCall
|
||||
SetCalls []CacheSetCall
|
||||
DeleteCalls []string
|
||||
maxSize int
|
||||
mu sync.RWMutex
|
||||
getCalls sync.Mutex
|
||||
setCalls sync.Mutex
|
||||
deleteCalls sync.Mutex
|
||||
@@ -423,17 +405,17 @@ type cacheEntry struct {
|
||||
|
||||
// CacheGetCall records parameters from a Get call
|
||||
type CacheGetCall struct {
|
||||
Timestamp time.Time
|
||||
Key string
|
||||
Found bool
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// CacheSetCall records parameters from a Set call
|
||||
type CacheSetCall struct {
|
||||
Key string
|
||||
Value any
|
||||
TTL time.Duration
|
||||
Timestamp time.Time
|
||||
Value any
|
||||
Key string
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// NewEnhancedMockCache creates a new enhanced cache mock
|
||||
|
||||
+15
-36
@@ -642,14 +642,10 @@ func (e *HTTPError) Error() string {
|
||||
// OIDCError represents OIDC-specific errors with context information.
|
||||
// It provides structured error reporting for authentication and authorization failures.
|
||||
type OIDCError struct {
|
||||
// Code identifies the specific error type
|
||||
Code string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Context contains additional error context (e.g., provider, session details)
|
||||
Cause error
|
||||
Context map[string]interface{}
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the OIDC error.
|
||||
@@ -669,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
|
||||
// SessionError represents session-related errors with context.
|
||||
// Used for session management, validation, and storage errors.
|
||||
type SessionError struct {
|
||||
// Operation describes what session operation failed
|
||||
Cause error
|
||||
Operation string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// SessionID identifies the session (if available)
|
||||
Message string
|
||||
SessionID string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
}
|
||||
|
||||
// Error returns the string representation of the session error.
|
||||
@@ -696,14 +688,10 @@ func (e *SessionError) Unwrap() error {
|
||||
// TokenError represents token-related errors with validation context.
|
||||
// Used for JWT validation, token refresh, and token format errors.
|
||||
type TokenError struct {
|
||||
// TokenType identifies the type of token (id_token, access_token, refresh_token)
|
||||
Cause error
|
||||
TokenType string
|
||||
// Reason describes why the token is invalid
|
||||
Reason string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Reason string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the token error.
|
||||
@@ -765,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
|
||||
// It provides fallback mechanisms when primary services are unavailable and monitors
|
||||
// service health to automatically recover when services become available again.
|
||||
type GracefulDegradation struct {
|
||||
// BaseRecoveryMechanism provides common functionality
|
||||
*BaseRecoveryMechanism
|
||||
// fallbacks stores service-specific fallback implementations
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
// healthChecks stores service health check functions
|
||||
healthChecks map[string]func() bool
|
||||
// degradedServices tracks which services are currently degraded
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
// config contains graceful degradation configuration
|
||||
config GracefulDegradationConfig
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// healthCheckTask manages background health checking
|
||||
healthCheckTask *BackgroundTask
|
||||
// stopChan signals shutdown
|
||||
stopChan chan struct{}
|
||||
// shutdownOnce ensures shutdown happens only once
|
||||
shutdownOnce sync.Once
|
||||
healthCheckTask *BackgroundTask
|
||||
stopChan chan struct{}
|
||||
config GracefulDegradationConfig
|
||||
mutex sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
|
||||
|
||||
@@ -20,10 +20,10 @@ import (
|
||||
func TestCircuitBreakerStateTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
failures int
|
||||
maxFailures int
|
||||
expectedStateBefore string
|
||||
expectedStateAfter string
|
||||
failures int
|
||||
maxFailures int
|
||||
}{
|
||||
{
|
||||
name: "stays closed below threshold",
|
||||
@@ -543,8 +543,8 @@ func TestRetryExecutorNetworkErrors(t *testing.T) {
|
||||
}, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
@@ -1647,8 +1647,8 @@ func TestGracefulDegradationFullScenario(t *testing.T) {
|
||||
|
||||
func TestIsTraefikDefaultCertError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -1680,8 +1680,8 @@ func TestIsTraefikDefaultCertError(t *testing.T) {
|
||||
|
||||
func TestIsEOFError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -1723,8 +1723,8 @@ func TestIsEOFError(t *testing.T) {
|
||||
|
||||
func TestIsCertificateError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -1811,8 +1811,8 @@ func TestRetryExecutorStartupErrors(t *testing.T) {
|
||||
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
@@ -1890,8 +1890,8 @@ func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
|
||||
re := NewRetryExecutor(DefaultRetryConfig(), nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
@@ -1977,9 +1977,9 @@ func circuitBreakerStateToString(state CircuitBreakerState) string {
|
||||
}
|
||||
|
||||
type mockNetError struct {
|
||||
msg string
|
||||
timeout bool
|
||||
temporary bool
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *mockNetError) Error() string { return e.msg }
|
||||
|
||||
@@ -10,16 +10,16 @@ import (
|
||||
type GoroutineManager struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
goroutines map[string]*managedGoroutine
|
||||
logger *Logger
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type managedGoroutine struct {
|
||||
name string
|
||||
cancel context.CancelFunc
|
||||
startTime time.Time
|
||||
cancel context.CancelFunc
|
||||
name string
|
||||
running bool
|
||||
}
|
||||
|
||||
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
|
||||
|
||||
// GoroutineStatus represents the status of a managed goroutine
|
||||
type GoroutineStatus struct {
|
||||
Name string
|
||||
Running bool
|
||||
StartTime time.Time
|
||||
Name string
|
||||
Runtime time.Duration
|
||||
Running bool
|
||||
}
|
||||
|
||||
// ErrShutdownTimeout is returned when shutdown times out
|
||||
|
||||
+12
-19
@@ -12,30 +12,23 @@ import (
|
||||
|
||||
// HTTPClientConfig provides configuration for creating HTTP clients
|
||||
type HTTPClientConfig struct {
|
||||
// Timeout for the entire request
|
||||
Timeout time.Duration
|
||||
// MaxRedirects allowed (0 means follow Go's default of 10)
|
||||
MaxRedirects int
|
||||
// UseCookieJar enables cookie jar for the client
|
||||
UseCookieJar bool
|
||||
// Connection settings
|
||||
IdleConnTimeout time.Duration
|
||||
MaxIdleConns int
|
||||
ReadBufferSize int
|
||||
DialTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
ResponseHeaderTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
IdleConnTimeout time.Duration
|
||||
// Connection pool settings
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
MaxConnsPerHost int
|
||||
// Buffer settings
|
||||
WriteBufferSize int
|
||||
ReadBufferSize int
|
||||
// Feature flags
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
MaxRedirects int
|
||||
MaxIdleConnsPerHost int
|
||||
Timeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig returns the default configuration for general use
|
||||
|
||||
@@ -110,9 +110,9 @@ func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMsg string
|
||||
config HTTPClientConfig
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
|
||||
+6
-6
@@ -12,19 +12,19 @@ import (
|
||||
|
||||
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
|
||||
type SharedTransportPool struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*sharedTransport
|
||||
maxConns int
|
||||
ctx context.Context
|
||||
transports map[string]*sharedTransport
|
||||
cancel context.CancelFunc
|
||||
clientCount int32 // SECURITY FIX: Track total HTTP clients
|
||||
maxClients int32 // SECURITY FIX: Limit total clients to 5
|
||||
maxConns int
|
||||
mu sync.RWMutex
|
||||
clientCount int32
|
||||
maxClients int32
|
||||
}
|
||||
|
||||
type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Valid token validation", func(t *testing.T) {
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
|
||||
|
||||
result := validator.ValidateToken(validToken)
|
||||
if !result.IsValid {
|
||||
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidJWTToken",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
|
||||
expectValid: true,
|
||||
description: "Valid JWT token should pass validation",
|
||||
},
|
||||
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "MaliciousJWTWithExtraData",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
|
||||
expectValid: false,
|
||||
description: "JWT with extra malicious data should fail validation",
|
||||
},
|
||||
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidEmail",
|
||||
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHTTPSURL",
|
||||
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
|
||||
name string
|
||||
claimName string
|
||||
claimValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidStringClaim",
|
||||
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
|
||||
name string
|
||||
headerName string
|
||||
headerValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHeader",
|
||||
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidUsername",
|
||||
|
||||
@@ -726,20 +726,20 @@ type MockConfig struct {
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
id string
|
||||
userID string
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
data map[string]interface{}
|
||||
id string
|
||||
userID string
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
UserID int
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Error error
|
||||
UserID int
|
||||
Duration time.Duration
|
||||
Success bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
Vendored
+14
-25
@@ -18,33 +18,22 @@ const (
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
L2Config *Config
|
||||
L1Config *Config
|
||||
RedisPrefix string
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
MaxMemoryBytes int64
|
||||
MaxSize int
|
||||
HealthCheckInterval time.Duration
|
||||
AsyncWrites bool
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+18
-28
@@ -13,40 +13,30 @@ import (
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
lastL2Error atomic.Value
|
||||
secondary CacheBackend
|
||||
primary CacheBackend
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
errors atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
ctx context.Context
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
@@ -82,9 +72,9 @@ func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
SyncWriteCacheTypes map[string]bool
|
||||
AsyncBufferSize int
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
|
||||
+6
-6
@@ -17,23 +17,23 @@ import (
|
||||
|
||||
// mockBackend is a simple mock implementation of CacheBackend for testing
|
||||
type mockBackend struct {
|
||||
pingError error
|
||||
data map[string]mockEntry
|
||||
stats map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
getCalls atomic.Int32
|
||||
setCalls atomic.Int32
|
||||
deleteCalls atomic.Int32
|
||||
failSet bool
|
||||
failGet bool
|
||||
failDelete bool
|
||||
failClear bool
|
||||
failPing bool
|
||||
pingError error
|
||||
stats map[string]interface{}
|
||||
getCalls atomic.Int32
|
||||
setCalls atomic.Int32
|
||||
deleteCalls atomic.Int32
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
value []byte
|
||||
}
|
||||
|
||||
// mockBatchBackend extends mockBackend with batch operations
|
||||
|
||||
Vendored
+14
-45
@@ -41,53 +41,22 @@ type CacheBackend interface {
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
StartTime time.Time
|
||||
LastErrorTime time.Time
|
||||
Type BackendType
|
||||
LastError string
|
||||
Deletes int64
|
||||
Errors int64
|
||||
Evictions int64
|
||||
CurrentSize int64
|
||||
MaxSize int64
|
||||
MemoryUsage int64
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
Sets int64
|
||||
Misses int64
|
||||
Uptime time.Duration
|
||||
Hits int64
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
|
||||
Vendored
+219
-200
@@ -2,23 +2,30 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default configuration values
|
||||
const (
|
||||
defaultShardCount = 256
|
||||
defaultMaxSize = int64(10000)
|
||||
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
|
||||
defaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element interface{} // *list.Element, using interface{} to avoid import cycle
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
@@ -29,17 +36,23 @@ func (item *memoryCacheItem) isExpired() bool {
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
|
||||
// The sharded design reduces lock contention by partitioning keys across multiple shards,
|
||||
// each with its own lock.
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
shards []*cacheShard
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
cleanupDone chan struct{}
|
||||
cleanupTicker *time.Ticker
|
||||
lastError string
|
||||
shardCount uint32
|
||||
shardMask uint32
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
|
||||
// Statistics
|
||||
// Global stats (aggregated from shards)
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
@@ -53,40 +66,59 @@ type MemoryCacheBackend struct {
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
// State
|
||||
closed atomic.Bool
|
||||
mu sync.RWMutex // For global operations like stats and error tracking
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
// NewMemoryCacheBackend creates a new sharded memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
maxSize = defaultMaxSize
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
maxMemory = defaultMaxMemory
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
shardCount := uint32(defaultShardCount)
|
||||
|
||||
// For very small caches, reduce shard count to maintain sensible per-shard limits
|
||||
// Ensure each shard can hold at least 2 items for proper LRU behavior
|
||||
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
|
||||
shardCount /= 2
|
||||
}
|
||||
if shardCount < 1 {
|
||||
shardCount = 1
|
||||
}
|
||||
|
||||
// Per-shard limits are soft hints; global limits are enforced
|
||||
// Give shards 2x the average to allow for uneven distribution
|
||||
shardMaxSize := (maxSize * 2) / int64(shardCount)
|
||||
if shardMaxSize < 4 {
|
||||
shardMaxSize = 4
|
||||
}
|
||||
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
|
||||
if shardMaxMemory < 4096 {
|
||||
shardMaxMemory = 4096 // Minimum 4KB per shard
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
shards: make([]*cacheShard, shardCount),
|
||||
shardCount: shardCount,
|
||||
shardMask: shardCount - 1, // For fast modulo with power-of-2
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for i := uint32(0); i < shardCount; i++ {
|
||||
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
@@ -96,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
|
||||
return m
|
||||
}
|
||||
|
||||
// getShard returns the shard for a given key
|
||||
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
|
||||
hash := fnv32(key)
|
||||
return m.shards[hash&m.shardMask]
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
@@ -108,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
// cleanupExpired removes all expired items from all shards
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
totalRemoved := 0
|
||||
for _, shard := range m.shards {
|
||||
totalRemoved += shard.cleanup()
|
||||
}
|
||||
|
||||
if totalRemoved > 0 {
|
||||
m.evictions.Add(int64(totalRemoved))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
shard := m.getShard(key)
|
||||
value, exists, expired := shard.get(key)
|
||||
|
||||
if expired {
|
||||
// Clean up expired item
|
||||
shard.delete(key)
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
@@ -182,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
// Calculate item size
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Enforce global limits before adding new item
|
||||
m.enforceGlobalLimits(itemSize)
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
expiresAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
shard := m.getShard(key)
|
||||
shard.set(key, value, expiresAt, itemSize)
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceGlobalLimits ensures global size and memory limits are respected
|
||||
// by evicting from shards when necessary
|
||||
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
|
||||
// Check and enforce size limit
|
||||
for {
|
||||
totalSize, totalMemory := m.getGlobalStats()
|
||||
|
||||
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
|
||||
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
|
||||
|
||||
if !needsSizeEviction && !needsMemoryEviction {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the shard with the most items and evict from it
|
||||
evicted := m.evictFromLargestShard()
|
||||
if !evicted {
|
||||
break // No more items to evict
|
||||
}
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// getGlobalStats returns the total size and memory usage across all shards
|
||||
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// evictFromLargestShard evicts the globally oldest item across all shards
|
||||
// This provides true LRU behavior even with sharding
|
||||
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
|
||||
var oldestShard *cacheShard
|
||||
var oldestTime time.Time
|
||||
|
||||
for _, shard := range m.shards {
|
||||
accessTime := shard.getOldestAccessTime()
|
||||
// Skip empty shards
|
||||
if accessTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Find the shard with the oldest (earliest) access time
|
||||
if oldestShard == nil || accessTime.Before(oldestTime) {
|
||||
oldestTime = accessTime
|
||||
oldestShard = shard
|
||||
}
|
||||
}
|
||||
|
||||
if oldestShard == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return oldestShard.evictOne()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
shard := m.getShard(key)
|
||||
if shard.delete(key) {
|
||||
m.deletes.Add(1)
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
shard := m.getShard(key)
|
||||
return shard.exists(key), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
@@ -297,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -314,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
var allKeys []string
|
||||
for _, shard := range m.shards {
|
||||
keys := shard.keys(pattern)
|
||||
allKeys = append(allKeys, keys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return allKeys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
// Size returns the total number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var total int64
|
||||
for _, shard := range m.shards {
|
||||
size, _ := shard.stats()
|
||||
total += size
|
||||
}
|
||||
|
||||
return m.currentSize, nil
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
@@ -345,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
ttl, exists := shard.ttl(key)
|
||||
if !exists {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
@@ -371,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
if !shard.expire(key, ttl) {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -394,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
// Aggregate stats from all shards
|
||||
var totalSize, totalMemory int64
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
@@ -417,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
CurrentSize: totalSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
MemoryUsage: totalMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
@@ -446,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
// Clear all shards
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -482,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// GetShardCount returns the number of shards (for testing/monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardCount() uint32 {
|
||||
return m.shardCount
|
||||
}
|
||||
|
||||
// GetShardStats returns per-shard statistics (for monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
|
||||
stats := make([]map[string]int64, m.shardCount)
|
||||
for i, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
stats[i] = map[string]int64{
|
||||
"size": size,
|
||||
"memory": memory,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
@@ -510,7 +526,10 @@ func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
// Simplified pattern matching
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
suffix := pattern[1:]
|
||||
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
|
||||
}
|
||||
return key == pattern
|
||||
}
|
||||
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheShard represents a single shard of the sharded cache
|
||||
// Each shard has its own lock for reduced contention
|
||||
type cacheShard struct {
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
mu sync.RWMutex
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
size int64
|
||||
memoryUsed int64
|
||||
}
|
||||
|
||||
// newCacheShard creates a new cache shard
|
||||
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
|
||||
return &cacheShard{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a value from this shard
|
||||
// Returns: value, exists, expired
|
||||
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
return nil, true, true // exists but expired
|
||||
}
|
||||
|
||||
// Update access time and LRU position under write lock
|
||||
s.mu.Lock()
|
||||
// Re-check item exists (could have been deleted)
|
||||
item, exists = s.items[key]
|
||||
if exists && !item.isExpired() {
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.MoveToFront(elem)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
return item.value, true, false
|
||||
}
|
||||
|
||||
// set stores a value in this shard
|
||||
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if s.maxSize > 0 && s.size >= s.maxSize {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
|
||||
// Remove old item if exists
|
||||
if oldItem, exists := s.items[key]; exists {
|
||||
s.memoryUsed -= oldItem.size
|
||||
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
s.size--
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: size,
|
||||
}
|
||||
|
||||
item.element = s.lruList.PushFront(item)
|
||||
s.items[key] = item
|
||||
s.size++
|
||||
s.memoryUsed += size
|
||||
}
|
||||
|
||||
// delete removes a key from this shard
|
||||
// Returns true if the key was deleted
|
||||
func (s *cacheShard) delete(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
|
||||
// exists checks if a key exists (and is not expired)
|
||||
func (s *cacheShard) exists(key string) bool {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// ttl returns the remaining TTL for a key
|
||||
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, true // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return remaining, true
|
||||
}
|
||||
|
||||
// expire updates the TTL for an existing key
|
||||
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// keys returns all non-expired keys matching the pattern
|
||||
func (s *cacheShard) keys(pattern string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range s.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// clear removes all items from this shard
|
||||
func (s *cacheShard) clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.items = make(map[string]*memoryCacheItem)
|
||||
s.lruList.Init()
|
||||
s.size = 0
|
||||
s.memoryUsed = 0
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
// Returns the number of items removed
|
||||
func (s *cacheShard) cleanup() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var toRemove []*memoryCacheItem
|
||||
for _, item := range s.items {
|
||||
if item.isExpired() {
|
||||
toRemove = append(toRemove, item)
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range toRemove {
|
||||
s.deleteItemLocked(item)
|
||||
}
|
||||
|
||||
return len(toRemove)
|
||||
}
|
||||
|
||||
// stats returns statistics for this shard
|
||||
func (s *cacheShard) stats() (size, memory int64) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size, s.memoryUsed
|
||||
}
|
||||
|
||||
// deleteItemLocked removes an item (must be called with lock held)
|
||||
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
delete(s.items, item.key)
|
||||
s.size--
|
||||
s.memoryUsed -= item.size
|
||||
}
|
||||
|
||||
// evictLRULocked evicts the least recently used item (must be called with lock held)
|
||||
func (s *cacheShard) evictLRULocked() bool {
|
||||
if s.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// evictOne evicts one item from this shard (for global limit enforcement)
|
||||
func (s *cacheShard) evictOne() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.evictLRULocked()
|
||||
}
|
||||
|
||||
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
|
||||
// Returns zero time if shard is empty
|
||||
func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.lruList.Len() == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
return item.accessedAt
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a hash of a string
|
||||
// This is a fast, well-distributed hash function
|
||||
func fnv32(key string) uint32 {
|
||||
const (
|
||||
offset32 = uint32(2166136261)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
hash := offset32
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash ^= uint32(key[i])
|
||||
hash *= prime32
|
||||
}
|
||||
return hash
|
||||
}
|
||||
+283
@@ -0,0 +1,283 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
|
||||
func TestShardedCache_ShardDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a cache with large enough size to have multiple shards
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add many items to see distribution
|
||||
numItems := 1000
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("dist-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("dist-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Check that items are distributed across multiple shards
|
||||
shardStats := backend.MemoryCacheBackend.GetShardStats()
|
||||
nonEmptyShards := 0
|
||||
for _, stat := range shardStats {
|
||||
if stat["size"] > 0 {
|
||||
nonEmptyShards++
|
||||
}
|
||||
}
|
||||
|
||||
// With good hash distribution, we should have items in multiple shards
|
||||
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
|
||||
}
|
||||
|
||||
// TestShardedCache_ShardCount tests that shard count adapts to cache size
|
||||
func TestShardedCache_ShardCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
maxSize int
|
||||
expectLowShards bool
|
||||
}{
|
||||
{5, true}, // Very small cache should have fewer shards
|
||||
{100, true}, // Small cache should have fewer shards
|
||||
{10000, false}, // Large cache should have default shards
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = tt.maxSize
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
shardCount := backend.MemoryCacheBackend.GetShardCount()
|
||||
|
||||
if tt.expectLowShards {
|
||||
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
|
||||
} else {
|
||||
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
|
||||
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "concurrent-same-key"
|
||||
initialValue := []byte("initial-value")
|
||||
|
||||
err = backend.Set(ctx, key, initialValue, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of reads and writes
|
||||
if j%3 == 0 {
|
||||
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, newValue, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Key should still exist
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
|
||||
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a small cache to force eviction
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
// Small delay to ensure different access times
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Access some items to make them recently used
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Add more items to trigger eviction
|
||||
for i := 10; i < 15; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Recently accessed items (5-9) should still exist
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item %d should exist", i)
|
||||
}
|
||||
|
||||
// Check eviction stats
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
|
||||
func TestShardedCache_StatsAggregation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items to multiple shards
|
||||
numItems := 100
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("stats-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read some items
|
||||
for i := 0; i < numItems/2; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Read non-existent items
|
||||
for i := 0; i < 10; i++ {
|
||||
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
|
||||
// Verify stats
|
||||
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
|
||||
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
|
||||
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
|
||||
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
|
||||
|
||||
// Verify hit rate
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
|
||||
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_Parallel benchmarks parallel access
|
||||
func BenchmarkShardedCache_Parallel(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10000; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("bench-key-%d", i%10000)
|
||||
backend.Get(ctx, key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
|
||||
func BenchmarkShardedCache_MixedOps(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("mixed-key-%d", i%1000)
|
||||
if i%3 == 0 {
|
||||
value := []byte(fmt.Sprintf("mixed-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
} else {
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
+20
-30
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
// Get TTL using the TTL method
|
||||
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
|
||||
if ttlErr != nil {
|
||||
// If we can't get TTL, still return the value with 0 TTL
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
// If it's not already []byte, return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"shard_count": m.MemoryCacheBackend.GetShardCount(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+107
-11
@@ -431,39 +431,135 @@ func isRetryableError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
// SetMany stores multiple values in Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For single items, use regular Set
|
||||
if len(items) == 1 {
|
||||
for key, value := range items {
|
||||
return r.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all SET commands
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
if ttl > 0 {
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs
|
||||
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs
|
||||
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
pipeline.Queue("SET", prefixedKey, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline SetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Check responses for errors (each should be "OK")
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
if str, ok := resp.(string); ok && str == "OK" {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
// GetMany retrieves multiple values from Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
// For single key, use regular Get
|
||||
if len(keys) == 1 {
|
||||
result := make(map[string][]byte)
|
||||
value, _, exists, err := r.Get(ctx, keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
result[keys[0]] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all GET commands
|
||||
prefixedKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
prefixedKeys[i] = r.prefixKey(key)
|
||||
pipeline.Queue("GET", prefixedKeys[i])
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Process responses
|
||||
result := make(map[string][]byte)
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
// Key doesn't exist
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
// Invalid response, skip this key
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
result[keys[i]] = []byte(value)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
+10
-16
@@ -9,30 +9,24 @@ import (
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
|
||||
// State
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
lastCheckTime atomic.Int64 // Unix timestamp
|
||||
|
||||
// Metrics
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
lastCheckTime atomic.Int64
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
CheckInterval time.Duration // How often to check health
|
||||
Timeout time.Duration // Timeout for health check
|
||||
UnhealthyThreshold int // Consecutive failures before marking unhealthy
|
||||
OnHealthChange func(healthy bool)
|
||||
CheckInterval time.Duration
|
||||
Timeout time.Duration
|
||||
UnhealthyThreshold int
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
|
||||
+461
@@ -0,0 +1,461 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a miniredis instance for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "test:",
|
||||
PoolSize: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
backend.Close()
|
||||
})
|
||||
|
||||
return mr, backend
|
||||
}
|
||||
|
||||
// TestPipeline_Basic tests basic pipeline functionality
|
||||
func TestPipeline_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.Addr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SingleCommand", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "single-key", "single-value")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
})
|
||||
|
||||
t.Run("MultipleCommands", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "key1", "value1")
|
||||
pipeline.Queue("SET", "key2", "value2")
|
||||
pipeline.Queue("SET", "key3", "value3")
|
||||
pipeline.Queue("GET", "key1")
|
||||
pipeline.Queue("GET", "key2")
|
||||
pipeline.Queue("GET", "key3")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 6)
|
||||
|
||||
// First 3 are SET responses
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
assert.Equal(t, "OK", responses[1])
|
||||
assert.Equal(t, "OK", responses[2])
|
||||
|
||||
// Last 3 are GET responses
|
||||
assert.Equal(t, "value1", responses[3])
|
||||
assert.Equal(t, "value2", responses[4])
|
||||
assert.Equal(t, "value3", responses[5])
|
||||
})
|
||||
|
||||
t.Run("EmptyPipeline", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, responses)
|
||||
})
|
||||
|
||||
t.Run("NilResponses", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("GET", "nonexistent-key")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Nil(t, responses[0])
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_SetMany tests pipelined SetMany
|
||||
func TestPipeline_SetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetManyItems", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetManyEmpty", func(t *testing.T) {
|
||||
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetManySingleItem", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"single-setmany": []byte("single-value"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
value, _, exists, err := backend.Get(ctx, "single-setmany")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("single-value"), value)
|
||||
})
|
||||
|
||||
t.Run("SetManyNoTTL", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"nottl-key1": []byte("value1"),
|
||||
"nottl-key2": []byte("value2"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Keys should exist
|
||||
for key := range items {
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_GetMany tests pipelined GetMany
|
||||
func TestPipeline_GetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("getmany-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetManyExisting", func(t *testing.T) {
|
||||
keys := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
keys[i] = fmt.Sprintf("getmany-key-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 10)
|
||||
|
||||
for i, key := range keys {
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyMixed", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"getmany-key-0", // exists
|
||||
"nonexistent-key-1", // doesn't exist
|
||||
"getmany-key-2", // exists
|
||||
"nonexistent-key-2", // doesn't exist
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
|
||||
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
|
||||
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
|
||||
assert.NotContains(t, results, "nonexistent-key-1")
|
||||
assert.NotContains(t, results, "nonexistent-key-2")
|
||||
})
|
||||
|
||||
t.Run("GetManyEmpty", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
|
||||
t.Run("GetManySingleKey", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
|
||||
})
|
||||
|
||||
t.Run("GetManyAllNonexistent", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"nonexistent-1",
|
||||
"nonexistent-2",
|
||||
"nonexistent-3",
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_LargeBatch tests pipelining with large batches
|
||||
func TestPipeline_LargeBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany100Items", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify random samples
|
||||
for _, i := range []int{0, 25, 50, 75, 99} {
|
||||
key := fmt.Sprintf("large-batch-%d", i)
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany100Items", func(t *testing.T) {
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("large-batch-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 100)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
|
||||
func TestPipeline_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some items
|
||||
items := map[string][]byte{
|
||||
"stats-key-1": []byte("value1"),
|
||||
"stats-key-2": []byte("value2"),
|
||||
}
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get items (some exist, some don't)
|
||||
keys := []string{
|
||||
"stats-key-1",
|
||||
"stats-key-2",
|
||||
"stats-key-nonexistent",
|
||||
}
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Check stats
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
|
||||
assert.Equal(t, int64(2), hits, "Should have 2 hits")
|
||||
assert.Equal(t, int64(1), misses, "Should have 1 miss")
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
|
||||
func BenchmarkPipeline_SetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
|
||||
func BenchmarkPipeline_GetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
// Prepare keys
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("bench-key-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
|
||||
func BenchmarkPipeline_VsSequential(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
keys := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("compare-key-%d", i)
|
||||
keys[i] = key
|
||||
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
|
||||
}
|
||||
|
||||
b.Run("Pipelined-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for key, value := range items {
|
||||
_ = backend.Set(ctx, key, value, time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for get benchmarks
|
||||
_ = backend.SetMany(ctx, items, time.Hour)
|
||||
|
||||
b.Run("Pipelined-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, key := range keys {
|
||||
_, _, _, _ = backend.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
+117
@@ -336,3 +336,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Pipeline represents a Redis pipeline for batch operations
|
||||
// It queues multiple commands and executes them in a single round-trip
|
||||
type Pipeline struct {
|
||||
conn *RedisConn
|
||||
commands []pipelineCommand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// pipelineCommand represents a single command in the pipeline
|
||||
type pipelineCommand struct {
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
// NewPipeline creates a new pipeline for the connection
|
||||
func (c *RedisConn) NewPipeline() *Pipeline {
|
||||
return &Pipeline{
|
||||
conn: c,
|
||||
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
||||
}
|
||||
}
|
||||
|
||||
// Queue adds a command to the pipeline
|
||||
func (p *Pipeline) Queue(command string, args ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.commands = append(p.commands, pipelineCommand{
|
||||
command: command,
|
||||
args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute sends all queued commands and returns all responses
|
||||
// Returns a slice of responses in the same order as commands were queued
|
||||
func (p *Pipeline) Execute() ([]interface{}, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.conn.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.conn.mu.Lock()
|
||||
defer p.conn.mu.Unlock()
|
||||
|
||||
// Set write timeout for all commands
|
||||
if p.conn.writeTimeout > 0 {
|
||||
// Use longer timeout for batch operations
|
||||
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second // Cap at 30 seconds
|
||||
}
|
||||
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Write all commands (pipelining - send all before reading any responses)
|
||||
writer := NewRESPWriter(p.conn.conn)
|
||||
for _, cmd := range p.commands {
|
||||
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
||||
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
||||
writer.Release()
|
||||
p.conn.closed.Store(true)
|
||||
return nil, fmt.Errorf("pipeline write error: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Release()
|
||||
|
||||
// Set read timeout for all responses
|
||||
if p.conn.readTimeout > 0 {
|
||||
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Read all responses
|
||||
responses := make([]interface{}, len(p.commands))
|
||||
reader := NewRESPReader(p.conn.conn)
|
||||
defer reader.Release()
|
||||
|
||||
for i := range p.commands {
|
||||
resp, err := reader.ReadResponse()
|
||||
if err != nil {
|
||||
// For nil responses, store nil instead of erroring
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
responses[i] = nil
|
||||
continue
|
||||
}
|
||||
p.conn.closed.Store(true)
|
||||
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
||||
}
|
||||
responses[i] = resp
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// Clear resets the pipeline for reuse
|
||||
func (p *Pipeline) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.commands = p.commands[:0]
|
||||
}
|
||||
|
||||
// Len returns the number of queued commands
|
||||
func (p *Pipeline) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
Vendored
+5
-5
@@ -15,8 +15,8 @@ import (
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
args []string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
@@ -205,9 +205,9 @@ func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected interface{}
|
||||
name string
|
||||
input string
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
@@ -440,10 +440,10 @@ func TestRESPHelpers(t *testing.T) {
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command []string
|
||||
response string
|
||||
expected interface{}
|
||||
name string
|
||||
response string
|
||||
command []string
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleflightCache wraps a CacheBackend with singleflight deduplication
|
||||
// to prevent thundering herd problems when multiple concurrent requests
|
||||
// try to fetch the same uncached key.
|
||||
type SingleflightCache struct {
|
||||
backend CacheBackend
|
||||
mu sync.Mutex
|
||||
calls map[string]*singleflightCall
|
||||
|
||||
// Metrics
|
||||
deduplicatedCalls atomic.Int64
|
||||
totalCalls atomic.Int64
|
||||
}
|
||||
|
||||
// singleflightCall represents an in-flight or completed fetch call
|
||||
type singleflightCall struct {
|
||||
wg sync.WaitGroup
|
||||
val []byte
|
||||
ttl time.Duration
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
// NewSingleflightCache creates a new singleflight-wrapped cache backend
|
||||
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
|
||||
return &SingleflightCache{
|
||||
backend: backend,
|
||||
calls: make(map[string]*singleflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Fetcher is a function type that fetches data when cache misses
|
||||
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
|
||||
|
||||
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
|
||||
// per key when there's a cache miss. Concurrent calls for the same key will
|
||||
// wait for the first call to complete and share its result.
|
||||
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
|
||||
s.totalCalls.Add(1)
|
||||
|
||||
// Try cache first
|
||||
value, _, exists, err := s.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Cache miss - use singleflight
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if there's already an in-flight call for this key
|
||||
if call, ok := s.calls[key]; ok {
|
||||
s.mu.Unlock()
|
||||
s.deduplicatedCalls.Add(1)
|
||||
|
||||
// Wait for the in-flight call to complete
|
||||
call.wg.Wait()
|
||||
|
||||
// Check context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Create new call
|
||||
call := &singleflightCall{}
|
||||
call.wg.Add(1)
|
||||
s.calls[key] = call
|
||||
s.mu.Unlock()
|
||||
|
||||
// Execute the fetcher
|
||||
call.val, call.ttl, call.err = fetcher(ctx)
|
||||
call.done = true
|
||||
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is cancelled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Signal waiting goroutines
|
||||
call.wg.Done()
|
||||
|
||||
// Clean up the call from the map after a short delay
|
||||
// This allows late arrivals to still benefit from the result
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
if c, ok := s.calls[key]; ok && c == call {
|
||||
delete(s.calls, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the underlying cache backend
|
||||
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
return s.backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores a value in the underlying cache backend
|
||||
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.backend.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the underlying cache backend
|
||||
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the underlying cache backend
|
||||
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the underlying cache backend
|
||||
func (s *SingleflightCache) Clear(ctx context.Context) error {
|
||||
return s.backend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics including singleflight metrics
|
||||
func (s *SingleflightCache) GetStats() map[string]interface{} {
|
||||
stats := s.backend.GetStats()
|
||||
|
||||
// Add singleflight-specific stats
|
||||
totalCalls := s.totalCalls.Load()
|
||||
deduped := s.deduplicatedCalls.Load()
|
||||
|
||||
stats["singleflight_total_calls"] = totalCalls
|
||||
stats["singleflight_deduplicated"] = deduped
|
||||
if totalCalls > 0 {
|
||||
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
|
||||
} else {
|
||||
stats["singleflight_dedup_rate"] = float64(0)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
stats["singleflight_inflight"] = len(s.calls)
|
||||
s.mu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend
|
||||
func (s *SingleflightCache) Close() error {
|
||||
return s.backend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (s *SingleflightCache) Ping(ctx context.Context) error {
|
||||
return s.backend.Ping(ctx)
|
||||
}
|
||||
|
||||
// GetBackend returns the underlying cache backend
|
||||
func (s *SingleflightCache) GetBackend() CacheBackend {
|
||||
return s.backend
|
||||
}
|
||||
|
||||
// ResetStats resets the singleflight statistics
|
||||
func (s *SingleflightCache) ResetStats() {
|
||||
s.totalCalls.Store(0)
|
||||
s.deduplicatedCalls.Store(0)
|
||||
}
|
||||
|
||||
// Ensure SingleflightCache implements CacheBackend
|
||||
var _ CacheBackend = (*SingleflightCache)(nil)
|
||||
+510
@@ -0,0 +1,510 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
|
||||
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CacheHit", func(t *testing.T) {
|
||||
key := "existing-key"
|
||||
value := []byte("existing-value")
|
||||
|
||||
// Pre-populate cache
|
||||
err := cache.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return []byte("fetched-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, result)
|
||||
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
|
||||
})
|
||||
|
||||
t.Run("CacheMiss", func(t *testing.T) {
|
||||
key := "missing-key"
|
||||
expectedValue := []byte("fetched-value")
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, result)
|
||||
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
|
||||
|
||||
// Verify value was stored in cache
|
||||
cached, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, cached)
|
||||
})
|
||||
|
||||
t.Run("FetcherError", func(t *testing.T) {
|
||||
key := "error-key"
|
||||
expectedErr := errors.New("fetch failed")
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Verify nothing was stored in cache
|
||||
_, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
|
||||
func TestSingleflightCache_Deduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "dedup-key"
|
||||
expectedValue := []byte("dedup-value")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([][]byte, concurrency)
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same result
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.NoError(t, errs[i])
|
||||
assert.Equal(t, expectedValue, results[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
|
||||
|
||||
// Verify deduplication stats
|
||||
stats := cache.GetStats()
|
||||
deduped := stats["singleflight_deduplicated"].(int64)
|
||||
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
|
||||
func TestSingleflightCache_DifferentKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetchStarted := make(chan struct{}, 3)
|
||||
fetchComplete := make(chan struct{})
|
||||
|
||||
fetcher := func(key string) Fetcher {
|
||||
return func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
fetchStarted <- struct{}{}
|
||||
<-fetchComplete // Wait for signal
|
||||
return []byte("value-" + key), time.Minute, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Launch concurrent requests for different keys
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("key-%d", idx)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all fetches to start
|
||||
for i := 0; i < 3; i++ {
|
||||
<-fetchStarted
|
||||
}
|
||||
|
||||
// All 3 fetches should be running in parallel
|
||||
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
|
||||
|
||||
// Release all fetches
|
||||
close(fetchComplete)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ContextCancellation tests context cancellation
|
||||
func TestSingleflightCache_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
key := "cancel-key"
|
||||
fetchStarted := make(chan struct{})
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
close(fetchStarted)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Start first request with long timeout
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}()
|
||||
|
||||
// Wait for fetch to start
|
||||
<-fetchStarted
|
||||
|
||||
// Start second request with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
|
||||
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "error-prop-key"
|
||||
expectedErr := errors.New("intentional error")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same error
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.Error(t, errs[i])
|
||||
assert.Equal(t, expectedErr, errs[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load())
|
||||
}
|
||||
|
||||
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
|
||||
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Set", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, _, exists, err := cache.Get(ctx, "set-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("set-value"), val)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, ttl, exists, err := cache.Get(ctx, "get-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("get-value"), val)
|
||||
assert.Greater(t, ttl, time.Duration(0))
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := cache.Delete(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := cache.Exists(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
exists, err := cache.Exists(ctx, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = cache.Exists(ctx, "exists-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := cache.Exists(ctx, "clear-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
err := cache.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Stats tests statistics tracking
|
||||
func TestSingleflightCache_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Make some calls
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := cache.GetStats()
|
||||
|
||||
// Check singleflight stats exist
|
||||
assert.Contains(t, stats, "singleflight_total_calls")
|
||||
assert.Contains(t, stats, "singleflight_deduplicated")
|
||||
assert.Contains(t, stats, "singleflight_dedup_rate")
|
||||
assert.Contains(t, stats, "singleflight_inflight")
|
||||
|
||||
// Verify values
|
||||
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
|
||||
|
||||
// Also check underlying backend stats are included
|
||||
assert.Contains(t, stats, "hits")
|
||||
assert.Contains(t, stats, "misses")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ResetStats tests stats reset
|
||||
func TestSingleflightCache_ResetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
|
||||
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
|
||||
|
||||
stats := cache.GetStats()
|
||||
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
|
||||
|
||||
// Reset stats
|
||||
cache.ResetStats()
|
||||
|
||||
stats = cache.GetStats()
|
||||
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
|
||||
}
|
||||
|
||||
// TestSingleflightCache_GetBackend tests GetBackend method
|
||||
func TestSingleflightCache_GetBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
assert.Equal(t, backend, cache.GetBackend())
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
|
||||
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key-%d", i%100)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
|
||||
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(time.Millisecond) // Simulate slow fetch
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
|
||||
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// All goroutines hit the same key
|
||||
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
|
||||
}
|
||||
})
|
||||
}
|
||||
Vendored
+28
-40
@@ -33,21 +33,19 @@ type Logger interface {
|
||||
|
||||
// Config provides configuration for the cache
|
||||
type Config struct {
|
||||
Logger Logger
|
||||
JWKConfig *JWKConfig
|
||||
MetadataConfig *MetadataConfig
|
||||
TokenConfig *TokenConfig
|
||||
Type Type
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableCompression bool
|
||||
MaxMemoryBytes int64
|
||||
MaxSize int
|
||||
EnableMetrics bool
|
||||
EnableAutoCleanup bool
|
||||
EnableMemoryLimit bool
|
||||
Logger Logger
|
||||
|
||||
// Type-specific configurations
|
||||
TokenConfig *TokenConfig
|
||||
MetadataConfig *MetadataConfig
|
||||
JWKConfig *JWKConfig
|
||||
EnableCompression bool
|
||||
}
|
||||
|
||||
// TokenConfig provides token-specific cache configuration
|
||||
@@ -59,11 +57,11 @@ type TokenConfig struct {
|
||||
|
||||
// MetadataConfig provides metadata-specific cache configuration
|
||||
type MetadataConfig struct {
|
||||
SecurityCriticalFields []string
|
||||
GracePeriod time.Duration
|
||||
ExtendedGracePeriod time.Duration
|
||||
MaxGracePeriod time.Duration
|
||||
SecurityCriticalMaxGracePeriod time.Duration
|
||||
SecurityCriticalFields []string
|
||||
}
|
||||
|
||||
// JWKConfig provides JWK-specific cache configuration
|
||||
@@ -75,45 +73,35 @@ type JWKConfig struct {
|
||||
|
||||
// Item represents a single cache entry
|
||||
type Item struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
Size int64
|
||||
ExpiresAt time.Time
|
||||
LastAccessed time.Time
|
||||
AccessCount int64
|
||||
Value interface{}
|
||||
Metadata map[string]interface{}
|
||||
element *list.Element
|
||||
Key string
|
||||
CacheType Type
|
||||
|
||||
// Type-specific metadata
|
||||
Metadata map[string]interface{}
|
||||
|
||||
// LRU list element reference
|
||||
element *list.Element
|
||||
Size int64
|
||||
AccessCount int64
|
||||
}
|
||||
|
||||
// Cache provides a single, unified cache implementation
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*Item
|
||||
lruList *list.List
|
||||
config Config
|
||||
logger Logger
|
||||
|
||||
// Memory management
|
||||
config Config
|
||||
ctx context.Context
|
||||
logger Logger
|
||||
cancel context.CancelFunc
|
||||
lruList *list.List
|
||||
items map[string]*Item
|
||||
stopCleanup chan bool
|
||||
wg sync.WaitGroup
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Metrics
|
||||
hits int64
|
||||
misses int64
|
||||
evictions int64
|
||||
sets int64
|
||||
|
||||
// Lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
stopCleanup chan bool
|
||||
closed int32
|
||||
hits int64
|
||||
misses int64
|
||||
evictions int64
|
||||
sets int64
|
||||
mu sync.RWMutex
|
||||
closed int32
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default cache configuration
|
||||
|
||||
Vendored
+11
-11
@@ -1750,19 +1750,19 @@ func TestAdvancedEdgeCases(t *testing.T) {
|
||||
|
||||
// Test with various data types
|
||||
testCases := []struct {
|
||||
key string
|
||||
value interface{}
|
||||
key string
|
||||
}{
|
||||
{"string", "test string"},
|
||||
{"int", 42},
|
||||
{"float", 3.14159},
|
||||
{"bool", true},
|
||||
{"slice", []string{"a", "b", "c"}},
|
||||
{"map", map[string]int{"one": 1, "two": 2}},
|
||||
{"nil", nil},
|
||||
{"empty-string", ""},
|
||||
{"empty-slice", []string{}},
|
||||
{"empty-map", map[string]interface{}{}},
|
||||
{key: "string", value: "test string"},
|
||||
{key: "int", value: 42},
|
||||
{key: "float", value: 3.14159},
|
||||
{key: "bool", value: true},
|
||||
{key: "slice", value: []string{"a", "b", "c"}},
|
||||
{key: "map", value: map[string]int{"one": 1, "two": 2}},
|
||||
{key: "nil", value: nil},
|
||||
{key: "empty-string", value: ""},
|
||||
{key: "empty-slice", value: []string{}},
|
||||
{key: "empty-map", value: map[string]interface{}{}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
Vendored
+2
-7
@@ -7,22 +7,17 @@ import (
|
||||
|
||||
// Manager manages multiple cache instances with singleton pattern
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Core caches
|
||||
logger Logger
|
||||
tokenCache *Cache
|
||||
metadataCache *Cache
|
||||
jwkCache *Cache
|
||||
sessionCache *Cache
|
||||
generalCache *Cache
|
||||
|
||||
// Typed wrappers
|
||||
typedToken *TokenCache
|
||||
typedMetadata *MetadataCache
|
||||
typedJWK *JWKCache
|
||||
typedSession *SessionCache
|
||||
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
+23
-42
@@ -48,23 +48,12 @@ func (s State) String() string {
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
OnStateChange func(from, to State)
|
||||
MaxFailures int
|
||||
FailureThreshold float64
|
||||
Timeout time.Duration
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
ResetTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
@@ -80,28 +69,20 @@ func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
nextRetryTime time.Time
|
||||
lastStateChange time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
config *CircuitBreakerConfig
|
||||
totalFailures atomic.Int64
|
||||
totalRequests atomic.Int64
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
stateMu sync.RWMutex
|
||||
timeMu sync.RWMutex
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
consecutiveFailures atomic.Int32
|
||||
state atomic.Int32
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
@@ -313,17 +294,17 @@ func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
|
||||
@@ -28,8 +28,8 @@ type mockBackend struct {
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
value []byte
|
||||
}
|
||||
|
||||
func newMockBackend() *mockBackend {
|
||||
|
||||
+28
-49
@@ -41,26 +41,13 @@ func (h HealthStatus) String() string {
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
CheckFunc func(ctx context.Context) error
|
||||
CheckInterval time.Duration
|
||||
Timeout time.Duration
|
||||
HealthyThreshold int
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
DegradedThreshold time.Duration
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
@@ -76,31 +63,23 @@ func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
config *HealthCheckConfig
|
||||
stopChan chan struct{}
|
||||
ticker *time.Ticker
|
||||
wg sync.WaitGroup
|
||||
statusChanges atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
consecutiveSuccesses atomic.Int32
|
||||
stopped atomic.Bool
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
@@ -342,19 +321,19 @@ func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
LastCheckTime time.Time
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
Status HealthStatus
|
||||
ConsecutiveFailures int32
|
||||
ConsecutiveSuccesses int32
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
|
||||
+7
-11
@@ -12,20 +12,16 @@ import (
|
||||
|
||||
// HealthCheckBackend wraps a cache backend with health checking
|
||||
type HealthCheckBackend struct {
|
||||
backend backends.CacheBackend
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Health tracking
|
||||
lastCheck time.Time
|
||||
backend backends.CacheBackend
|
||||
ctx context.Context
|
||||
config *HealthCheckConfig
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
checkMutex sync.RWMutex
|
||||
status atomic.Int32
|
||||
consecutiveFails atomic.Int32
|
||||
consecutiveOK atomic.Int32
|
||||
lastCheck time.Time
|
||||
checkMutex sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthCheckBackend creates a new health check wrapped backend
|
||||
|
||||
Vendored
+2
-2
@@ -292,12 +292,12 @@ type SessionCache struct {
|
||||
|
||||
// SessionData represents session information
|
||||
type SessionData struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Claims map[string]interface{} `json:"claims"`
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Claims map[string]interface{} `json:"claims"`
|
||||
}
|
||||
|
||||
// NewSessionCache creates a new session cache
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
|
||||
// Mock logger for testing
|
||||
type mockLogger struct {
|
||||
mu sync.Mutex
|
||||
logs []string
|
||||
errLogs []string
|
||||
debugLog []string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Logf(format string, args ...interface{}) {
|
||||
|
||||
+11
-11
@@ -19,20 +19,20 @@ type Logger interface {
|
||||
|
||||
// BackgroundTask represents a recurring background task
|
||||
type BackgroundTask struct {
|
||||
name string
|
||||
interval time.Duration
|
||||
taskFunc func()
|
||||
lastRun time.Time
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
ticker *time.Ticker
|
||||
stopChan chan bool
|
||||
isRunning int32
|
||||
logger Logger
|
||||
waitGroup *sync.WaitGroup
|
||||
lastRun time.Time
|
||||
taskFunc func()
|
||||
cancelFunc context.CancelFunc
|
||||
name string
|
||||
runCount int64
|
||||
errorCount int64
|
||||
interval time.Duration
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
isRunning int32
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task
|
||||
@@ -183,11 +183,11 @@ func (bt *BackgroundTask) IsRunning() bool {
|
||||
|
||||
// TaskRegistry manages all background tasks
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
logger Logger
|
||||
maxTasks int
|
||||
tasks map[string]*BackgroundTask
|
||||
circuitBreaker *TaskCircuitBreaker
|
||||
maxTasks int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// globalTaskRegistry is the singleton task registry
|
||||
|
||||
+13
-13
@@ -11,14 +11,14 @@ import (
|
||||
|
||||
// TaskCircuitBreaker prevents task creation failures from cascading
|
||||
type TaskCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
logger Logger
|
||||
taskFailures map[string]int32
|
||||
timeout time.Duration
|
||||
mu sync.RWMutex
|
||||
failureThreshold int32
|
||||
failureCount int32
|
||||
lastFailureTime time.Time
|
||||
timeout time.Duration
|
||||
state int32 // 0: closed, 1: open
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
taskFailures map[string]int32
|
||||
state int32
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents the state of the circuit breaker
|
||||
@@ -140,14 +140,14 @@ func (cb *TaskCircuitBreaker) GetState() CircuitBreakerState {
|
||||
|
||||
// TaskMemoryMonitor monitors memory usage and can trigger cleanup
|
||||
type TaskMemoryMonitor struct {
|
||||
lastCheck time.Time
|
||||
logger Logger
|
||||
registry *TaskRegistry
|
||||
stopChan chan bool
|
||||
memoryThreshold uint64
|
||||
checkInterval time.Duration
|
||||
isMonitoring int32
|
||||
stopChan chan bool
|
||||
lastCheck time.Time
|
||||
mu sync.RWMutex
|
||||
isMonitoring int32
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -310,13 +310,13 @@ func (tmm *TaskMemoryMonitor) GetStats() map[string]interface{} {
|
||||
|
||||
// WorkerPool manages a pool of worker goroutines for task execution
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
taskQueue chan func()
|
||||
workerWg sync.WaitGroup
|
||||
isRunning int32
|
||||
logger Logger
|
||||
taskQueue chan func()
|
||||
stopChan chan bool
|
||||
metrics WorkerPoolMetrics
|
||||
workerWg sync.WaitGroup
|
||||
workers int
|
||||
isRunning int32
|
||||
}
|
||||
|
||||
// WorkerPoolMetrics tracks worker pool performance
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
type FeatureFlag struct {
|
||||
name string
|
||||
description string
|
||||
enabled atomic.Bool
|
||||
mu sync.RWMutex
|
||||
callbacks []func(bool)
|
||||
mu sync.RWMutex
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// FeatureManager manages all feature flags in the application
|
||||
|
||||
+19
-28
@@ -14,50 +14,41 @@ import (
|
||||
// and resource leaks. It provides centralized management of HTTP client transports with
|
||||
// proper lifecycle management and security controls.
|
||||
type TransportPool struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*sharedTransport
|
||||
maxConns int
|
||||
ctx context.Context
|
||||
transports map[string]*sharedTransport
|
||||
cancel context.CancelFunc
|
||||
clientCount int32 // Track total HTTP clients
|
||||
maxClients int32 // Limit total clients
|
||||
maxConns int
|
||||
mu sync.RWMutex
|
||||
clientCount int32
|
||||
maxClients int32
|
||||
}
|
||||
|
||||
// sharedTransport wraps an HTTP transport with reference counting
|
||||
type sharedTransport struct {
|
||||
transport *http.Transport
|
||||
refCount int32
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
config TransportConfig
|
||||
refCount int32
|
||||
}
|
||||
|
||||
// TransportConfig defines configuration for HTTP transports
|
||||
type TransportConfig struct {
|
||||
// Timeouts
|
||||
DialTimeout time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
ResponseHeaderTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
IdleConnTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
|
||||
// Connection limits
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
MaxConnsPerHost int
|
||||
|
||||
// Features
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
|
||||
// Buffer sizes
|
||||
WriteBufferSize int
|
||||
ReadBufferSize int
|
||||
|
||||
// TLS
|
||||
InsecureSkipVerify bool
|
||||
MinTLSVersion uint16
|
||||
TLSHandshakeTimeout time.Duration
|
||||
MaxIdleConns int
|
||||
DialTimeout time.Duration
|
||||
MaxIdleConnsPerHost int
|
||||
ReadBufferSize int
|
||||
MinTLSVersion uint16
|
||||
ForceHTTP2 bool
|
||||
DisableCompression bool
|
||||
InsecureSkipVerify bool
|
||||
DisableKeepAlives bool
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -154,10 +154,10 @@ func TestAzureProvider_ValidateTokens(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *mockSession
|
||||
verifierError error
|
||||
session *mockSession
|
||||
cacheData map[string]interface{}
|
||||
name string
|
||||
expectedResult ValidationResult
|
||||
}{
|
||||
{
|
||||
@@ -369,9 +369,9 @@ func TestAzureProvider_OfflineAccessHandling(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
inputScopes []string
|
||||
expectedCount int // Expected number of offline_access scopes (should be 1)
|
||||
description string
|
||||
inputScopes []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "No offline_access - should add one",
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockSession struct {
|
||||
authenticated bool
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
authenticated bool
|
||||
}
|
||||
|
||||
func (s *mockSession) GetIDToken() string { return s.idToken }
|
||||
@@ -338,10 +338,10 @@ func TestBaseProvider_ValidateTokenExpiry(t *testing.T) {
|
||||
gracePeriod := 5 * time.Minute
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]interface{}
|
||||
cacheFound bool
|
||||
name string
|
||||
expectedResult ValidationResult
|
||||
cacheFound bool
|
||||
}{
|
||||
{
|
||||
name: "Token not found in cache, has refresh token",
|
||||
@@ -438,10 +438,10 @@ func TestBaseProvider_ValidateTokenExpiry_NoRefreshToken(t *testing.T) {
|
||||
gracePeriod := 5 * time.Minute
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]interface{}
|
||||
cacheFound bool
|
||||
name string
|
||||
expectedResult ValidationResult
|
||||
cacheFound bool
|
||||
}{
|
||||
{
|
||||
name: "Token not found in cache, no refresh token",
|
||||
|
||||
@@ -25,9 +25,9 @@ func TestProviderFactory_CreateProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
errMsg string
|
||||
expectedType ProviderType
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Google provider",
|
||||
@@ -158,10 +158,10 @@ func TestProviderFactory_CreateProviderByType(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errMsg string
|
||||
providerType ProviderType
|
||||
expectedType ProviderType
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Generic provider",
|
||||
|
||||
@@ -136,9 +136,9 @@ func TestGenericProvider_ValidateTokens(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *mockSession
|
||||
verifierError error
|
||||
session *mockSession
|
||||
name string
|
||||
expectedResult ValidationResult
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -172,8 +172,8 @@ func TestGoogleProvider_OfflineAccessFiltering(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
inputScopes []string
|
||||
description string
|
||||
inputScopes []string
|
||||
}{
|
||||
{
|
||||
name: "Multiple offline_access occurrences",
|
||||
|
||||
@@ -82,9 +82,9 @@ func TestProviderRegistry_GetProviderByType(t *testing.T) {
|
||||
registry.RegisterProvider(googleProvider)
|
||||
|
||||
tests := []struct {
|
||||
expected OIDCProvider
|
||||
name string
|
||||
providerType ProviderType
|
||||
expected OIDCProvider
|
||||
}{
|
||||
{
|
||||
name: "Get Generic provider",
|
||||
@@ -180,9 +180,9 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
|
||||
registry.RegisterProvider(gitlabProvider)
|
||||
|
||||
tests := []struct {
|
||||
expected OIDCProvider
|
||||
name string
|
||||
issuerURL string
|
||||
expected OIDCProvider
|
||||
}{
|
||||
{
|
||||
name: "Google provider detection",
|
||||
@@ -640,9 +640,9 @@ func TestProviderRegistry_GitLabDetection_RealWorldURLs(t *testing.T) {
|
||||
registry.RegisterProvider(githubProvider)
|
||||
|
||||
realWorldTests := []struct {
|
||||
expected OIDCProvider
|
||||
name string
|
||||
issuerURL string
|
||||
expected OIDCProvider
|
||||
}{
|
||||
// Actual self-hosted GitLab examples from issue #61
|
||||
{
|
||||
|
||||
@@ -20,8 +20,8 @@ func TestValidateIssuerURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid https URL",
|
||||
@@ -106,8 +106,8 @@ func TestValidateClientID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid client ID",
|
||||
@@ -173,9 +173,9 @@ func TestValidateClientID(t *testing.T) {
|
||||
func TestValidateScopes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errMsg string
|
||||
scopes []string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid scopes with openid",
|
||||
@@ -248,8 +248,8 @@ func TestValidateRedirectURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURL string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid https redirect URL",
|
||||
@@ -315,11 +315,11 @@ func TestValidateRedirectURL(t *testing.T) {
|
||||
// TestValidateProviderSpecificConfig tests provider-specific configuration validation
|
||||
func TestValidateProviderSpecificConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider OIDCProvider
|
||||
config map[string]interface{}
|
||||
wantErr bool
|
||||
name string
|
||||
errMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid Google config",
|
||||
@@ -458,8 +458,8 @@ func TestValidateGoogleConfig_EdgeCases(t *testing.T) {
|
||||
googleProvider := NewGoogleProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]interface{}
|
||||
name string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
@@ -502,10 +502,10 @@ func TestValidateAzureConfig_EdgeCases(t *testing.T) {
|
||||
azureProvider := NewAzureProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]interface{}
|
||||
wantErr bool
|
||||
name string
|
||||
errMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid tenant ID format",
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
|
||||
// ProviderWarning represents a warning about provider limitations or requirements.
|
||||
type ProviderWarning struct {
|
||||
ProviderType ProviderType
|
||||
Level string // "info", "warning", "error"
|
||||
Level string
|
||||
Message string
|
||||
ProviderType ProviderType
|
||||
}
|
||||
|
||||
// GetProviderWarnings returns warnings about provider-specific limitations.
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
func TestGetProviderWarnings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
checkContent string
|
||||
providerType ProviderType
|
||||
expectCount int
|
||||
checkContent string
|
||||
}{
|
||||
{
|
||||
name: "GitHub has OAuth 2.0 warning",
|
||||
|
||||
@@ -34,21 +34,15 @@ type Logger interface {
|
||||
// for all recovery mechanism implementations. It handles request counting,
|
||||
// success/failure tracking, and timestamp management in a thread-safe manner.
|
||||
type BaseRecoveryMechanism struct {
|
||||
// name identifies the recovery mechanism instance
|
||||
name string
|
||||
// logger provides structured logging capabilities
|
||||
logger Logger
|
||||
|
||||
// Metrics tracked with atomic operations for thread safety
|
||||
logger Logger
|
||||
name string
|
||||
lastSuccessStr string
|
||||
lastFailureStr string
|
||||
totalRequests int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
lastSuccessStr string
|
||||
lastFailureStr string
|
||||
|
||||
// mutexes for thread-safe timestamp updates
|
||||
successMutex sync.RWMutex
|
||||
failureMutex sync.RWMutex
|
||||
successMutex sync.RWMutex
|
||||
failureMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBaseRecoveryMechanism creates a new base recovery mechanism with the given name and logger.
|
||||
@@ -182,10 +176,10 @@ const (
|
||||
|
||||
// HTTPError represents an HTTP error with status code and message
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Message string
|
||||
Body []byte
|
||||
Headers map[string]string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
|
||||
@@ -60,20 +60,14 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
|
||||
// It prevents cascading failures by temporarily blocking requests to a failing service.
|
||||
type CircuitBreaker struct {
|
||||
*BaseRecoveryMechanism
|
||||
config CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state int32 // atomic: CircuitBreakerState
|
||||
lastStateChange time.Time
|
||||
stateMutex sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures int32 // atomic
|
||||
consecutiveSuccesses int32 // atomic
|
||||
|
||||
// Half-open state management
|
||||
halfOpenRequests int32 // atomic
|
||||
*BaseRecoveryMechanism
|
||||
config CircuitBreakerConfig
|
||||
stateMutex sync.RWMutex
|
||||
state int32
|
||||
consecutiveFailures int32
|
||||
consecutiveSuccesses int32
|
||||
halfOpenRequests int32
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the given configuration
|
||||
|
||||
@@ -15,20 +15,13 @@ import (
|
||||
|
||||
// RetryConfig defines configuration for the retry executor
|
||||
type RetryConfig struct {
|
||||
// MaxAttempts is the maximum number of retry attempts
|
||||
MaxAttempts int
|
||||
// InitialDelay is the initial delay between retries
|
||||
InitialDelay time.Duration
|
||||
// MaxDelay is the maximum delay between retries
|
||||
MaxDelay time.Duration
|
||||
// Multiplier is the backoff multiplier
|
||||
Multiplier float64
|
||||
// RandomizationFactor adds jitter to delays (0.0 to 1.0)
|
||||
RandomizationFactor float64
|
||||
// RetryableErrors defines which errors should trigger a retry
|
||||
RetryableErrors []string
|
||||
// RetryableStatusCodes defines which HTTP status codes should trigger a retry
|
||||
RetryableErrors []string
|
||||
RetryableStatusCodes []int
|
||||
MaxAttempts int
|
||||
InitialDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
RandomizationFactor float64
|
||||
}
|
||||
|
||||
// DefaultRetryConfig returns sensible default retry configuration
|
||||
@@ -46,13 +39,11 @@ func DefaultRetryConfig() RetryConfig {
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff
|
||||
type RetryExecutor struct {
|
||||
lastRetryTime time.Time
|
||||
*BaseRecoveryMechanism
|
||||
config RetryConfig
|
||||
|
||||
// Metrics
|
||||
config RetryConfig
|
||||
totalRetries int64
|
||||
maxRetriesHit int64
|
||||
lastRetryTime time.Time
|
||||
retryTimeMutex sync.RWMutex
|
||||
}
|
||||
|
||||
|
||||
@@ -273,17 +273,17 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
|
||||
executor := NewRetryExecutor(config, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{"nil error", nil, false},
|
||||
{"connection refused", errors.New("connection refused"), true},
|
||||
{"timeout", errors.New("TIMEOUT"), true}, // case insensitive
|
||||
{"EOF", errors.New("EOF"), false},
|
||||
{"random error", errors.New("something else"), false},
|
||||
{"context cancelled", context.Canceled, false},
|
||||
{"context deadline exceeded", context.DeadlineExceeded, false},
|
||||
{name: "nil error", err: nil, expected: false},
|
||||
{name: "connection refused", err: errors.New("connection refused"), expected: true},
|
||||
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
|
||||
{name: "EOF", err: errors.New("EOF"), expected: false},
|
||||
{name: "random error", err: errors.New("something else"), expected: false},
|
||||
{name: "context cancelled", err: context.Canceled, expected: false},
|
||||
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -13,10 +13,10 @@ import (
|
||||
|
||||
// Mock logger for testing
|
||||
type mockLogger struct {
|
||||
mu sync.Mutex
|
||||
logs []string
|
||||
errLogs []string
|
||||
debugLog []string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Logf(format string, args ...interface{}) {
|
||||
@@ -202,13 +202,13 @@ func TestBaseRecoveryMechanism_ConcurrentAccess(t *testing.T) {
|
||||
// CircuitBreakerState tests
|
||||
func TestCircuitBreakerState_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
state CircuitBreakerState
|
||||
expected string
|
||||
state CircuitBreakerState
|
||||
}{
|
||||
{CircuitBreakerClosed, "closed"},
|
||||
{CircuitBreakerOpen, "open"},
|
||||
{CircuitBreakerHalfOpen, "half-open"},
|
||||
{CircuitBreakerState(99), "unknown"},
|
||||
{state: CircuitBreakerClosed, expected: "closed"},
|
||||
{state: CircuitBreakerOpen, expected: "open"},
|
||||
{state: CircuitBreakerHalfOpen, expected: "half-open"},
|
||||
{state: CircuitBreakerState(99), expected: "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -29,26 +29,26 @@ type JWK struct {
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// IntrospectionResponse represents a token introspection response
|
||||
type IntrospectionResponse struct {
|
||||
Active bool `json:"active"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
Nbf int64 `json:"nbf,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Aud string `json:"aud,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
Nbf int64 `json:"nbf,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
// JWKCache is a testify mock for JWK caching operations
|
||||
|
||||
@@ -8,16 +8,16 @@ import (
|
||||
|
||||
// SessionData represents session data for testing
|
||||
type SessionData struct {
|
||||
Claims map[string]interface{}
|
||||
Email string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
IDToken string
|
||||
Expiry int64
|
||||
Nonce string
|
||||
State string
|
||||
CodeVerifier string
|
||||
RedirectURL string
|
||||
Claims map[string]interface{}
|
||||
Expiry int64
|
||||
}
|
||||
|
||||
// SessionManager is a testify mock for session management
|
||||
|
||||
@@ -16,45 +16,30 @@ import (
|
||||
|
||||
// OIDCServerConfig configures the mock OIDC server behavior
|
||||
type OIDCServerConfig struct {
|
||||
// Identity
|
||||
Issuer string
|
||||
|
||||
// Discovery
|
||||
ScopesSupported []string
|
||||
ResponseTypesSupported []string
|
||||
JWKSResponse map[string]interface{}
|
||||
TokenFixture *fixtures.TokenFixture
|
||||
UserinfoError *OIDCError
|
||||
UserinfoResponse map[string]interface{}
|
||||
IntrospectionResponse map[string]interface{}
|
||||
JWKSError *OIDCError
|
||||
RefreshError *OIDCError
|
||||
TokenResponse map[string]interface{}
|
||||
TokenError *OIDCError
|
||||
IntrospectionError *OIDCError
|
||||
RefreshResponse map[string]interface{}
|
||||
Issuer string
|
||||
GrantTypesSupported []string
|
||||
ClaimsSupported []string
|
||||
TokenEndpointAuthMethods []string
|
||||
|
||||
// Token fixture for signing
|
||||
TokenFixture *fixtures.TokenFixture
|
||||
|
||||
// Token endpoint behavior
|
||||
TokenResponse map[string]interface{}
|
||||
TokenError *OIDCError
|
||||
TokenDelay time.Duration
|
||||
RefreshResponse map[string]interface{}
|
||||
RefreshError *OIDCError
|
||||
|
||||
// JWKS behavior
|
||||
JWKSResponse map[string]interface{}
|
||||
JWKSError *OIDCError
|
||||
JWKSDelay time.Duration
|
||||
|
||||
// Introspection behavior
|
||||
IntrospectionResponse map[string]interface{}
|
||||
IntrospectionError *OIDCError
|
||||
|
||||
// Userinfo behavior
|
||||
UserinfoResponse map[string]interface{}
|
||||
UserinfoError *OIDCError
|
||||
|
||||
// Simulation flags
|
||||
SimulateTimeout bool
|
||||
TimeoutDuration time.Duration
|
||||
RateLimitAfter int
|
||||
FailAfterN int
|
||||
FailWithStatus int
|
||||
ScopesSupported []string
|
||||
ClaimsSupported []string
|
||||
ResponseTypesSupported []string
|
||||
FailAfterN int
|
||||
JWKSDelay time.Duration
|
||||
TimeoutDuration time.Duration
|
||||
RateLimitAfter int
|
||||
TokenDelay time.Duration
|
||||
FailWithStatus int
|
||||
SimulateTimeout bool
|
||||
}
|
||||
|
||||
// OIDCError represents an OAuth error response
|
||||
@@ -67,9 +52,9 @@ type OIDCError struct {
|
||||
type OIDCServer struct {
|
||||
*httptest.Server
|
||||
Config *OIDCServerConfig
|
||||
RequestCount int32
|
||||
mu sync.Mutex
|
||||
requests []*http.Request
|
||||
mu sync.Mutex
|
||||
RequestCount int32
|
||||
}
|
||||
|
||||
// NewOIDCServer creates a new mock OIDC server
|
||||
|
||||
@@ -135,9 +135,9 @@ func TestIsTestMode(t *testing.T) {
|
||||
// We'll test what we can control via environment variables.
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func()
|
||||
cleanup func()
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -206,8 +206,8 @@ func TestIsTestMode(t *testing.T) {
|
||||
func TestIsTestModeEdgeCases(t *testing.T) {
|
||||
// Test with various environment variable combinations
|
||||
tests := []struct {
|
||||
name string
|
||||
env map[string]string
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "all env vars empty",
|
||||
@@ -560,11 +560,11 @@ func TestIsTestModeYaegiCompiler(t *testing.T) {
|
||||
|
||||
// mockLogger is a simple mock implementation for testing
|
||||
type mockLogger struct {
|
||||
lastFormat string
|
||||
lastArgs []interface{}
|
||||
infoCalls int
|
||||
debugCalls int
|
||||
errorCalls int
|
||||
lastFormat string
|
||||
lastArgs []interface{}
|
||||
}
|
||||
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {
|
||||
|
||||
@@ -21,25 +21,16 @@ import (
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It can represent different key types including RSA, EC, and symmetric keys.
|
||||
type JWK struct {
|
||||
// Key type (e.g., "RSA", "EC", "oct")
|
||||
Kty string `json:"kty"`
|
||||
// Key use (e.g., "sig" for signature, "enc" for encryption)
|
||||
Use string `json:"use,omitempty"`
|
||||
// Key operations allowed
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use,omitempty"`
|
||||
Alg string `json:"alg,omitempty"`
|
||||
Kid string `json:"kid,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Crv string `json:"crv,omitempty"`
|
||||
X string `json:"x,omitempty"`
|
||||
Y string `json:"y,omitempty"`
|
||||
KeyOps []string `json:"key_ops,omitempty"`
|
||||
// Algorithm intended for use with this key
|
||||
Alg string `json:"alg,omitempty"`
|
||||
// Key ID
|
||||
Kid string `json:"kid,omitempty"`
|
||||
|
||||
// RSA specific fields
|
||||
N string `json:"n,omitempty"` // Modulus
|
||||
E string `json:"e,omitempty"` // Exponent
|
||||
|
||||
// EC specific fields
|
||||
Crv string `json:"crv,omitempty"` // Curve
|
||||
X string `json:"x,omitempty"` // X coordinate
|
||||
Y string `json:"y,omitempty"` // Y coordinate
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JSON Web Keys.
|
||||
|
||||
+9
-9
@@ -309,18 +309,18 @@ func TestJWKCacheCleanupAndClose(t *testing.T) {
|
||||
func TestFetchJWKSEdgeCases(t *testing.T) {
|
||||
t.Run("handles various HTTP status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
errContains string
|
||||
status int
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{200, false, ""},
|
||||
{400, true, "400"},
|
||||
{401, true, "401"},
|
||||
{403, true, "403"},
|
||||
{404, true, "404"},
|
||||
{500, true, "500"},
|
||||
{502, true, "502"},
|
||||
{503, true, "503"},
|
||||
{status: 200, wantErr: false, errContains: ""},
|
||||
{status: 400, wantErr: true, errContains: "400"},
|
||||
{status: 401, wantErr: true, errContains: "401"},
|
||||
{status: 403, wantErr: true, errContains: "403"},
|
||||
{status: 404, wantErr: true, errContains: "404"},
|
||||
{status: 500, wantErr: true, errContains: "500"},
|
||||
{status: 502, wantErr: true, errContains: "502"},
|
||||
{status: 503, wantErr: true, errContains: "503"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -5,6 +5,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -392,11 +394,15 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
t.registrationURL = metadata.RegistrationURL // OIDC Dynamic Client Registration endpoint (RFC 7591)
|
||||
|
||||
// Copy values for logging after unlock to avoid race conditions
|
||||
introspectionURL := t.introspectionURL
|
||||
registrationURL := t.registrationURL
|
||||
|
||||
t.metadataMu.Unlock()
|
||||
|
||||
// Log introspection endpoint availability for opaque token support
|
||||
if t.introspectionURL != "" {
|
||||
t.logger.Debugf("Token introspection endpoint discovered: %s", t.introspectionURL)
|
||||
if introspectionURL != "" {
|
||||
t.logger.Debugf("Token introspection endpoint discovered: %s", introspectionURL)
|
||||
if t.allowOpaqueTokens {
|
||||
t.logger.Debugf("Opaque token support enabled with introspection endpoint")
|
||||
}
|
||||
@@ -405,8 +411,8 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
}
|
||||
|
||||
// Log registration endpoint availability
|
||||
if t.registrationURL != "" {
|
||||
t.logger.Debugf("Dynamic client registration endpoint discovered: %s", t.registrationURL)
|
||||
if registrationURL != "" {
|
||||
t.logger.Debugf("Dynamic client registration endpoint discovered: %s", registrationURL)
|
||||
}
|
||||
|
||||
// Perform Dynamic Client Registration if enabled and ClientID is not set
|
||||
@@ -474,7 +480,10 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
|
||||
func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
// Use singleton resource manager for metadata refresh
|
||||
rm := GetResourceManager()
|
||||
taskName := "singleton-metadata-refresh"
|
||||
// Use last 6 chars of provider URL hash to create unique task name per realm
|
||||
// This fixes multi-realm support where different Keycloak realms need separate refresh tasks
|
||||
hash := sha256.Sum256([]byte(providerURL))
|
||||
taskName := "singleton-metadata-refresh-" + hex.EncodeToString(hash[:])[0:6]
|
||||
|
||||
// Create refresh function
|
||||
refreshFunc := func() {
|
||||
@@ -510,6 +519,27 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
}
|
||||
}
|
||||
|
||||
// attemptMetadataRecovery tries to fetch provider metadata when the system is in a failed state.
|
||||
// This is called periodically (every 30s) when requests come in and metadata is unavailable.
|
||||
// It allows automatic recovery when the OIDC provider becomes available again.
|
||||
func (t *TraefikOidc) attemptMetadataRecovery() {
|
||||
if t.metadataCache == nil || t.httpClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to fetch metadata (single attempt, no aggressive retry here since this runs every 30s)
|
||||
metadata, err := t.metadataCache.GetMetadata(t.providerURL, t.httpClient, t.logger)
|
||||
if err != nil {
|
||||
t.safeLogDebugf("Metadata recovery attempt failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
t.updateMetadataEndpoints(metadata)
|
||||
t.safeLogInfo("Successfully recovered OIDC provider metadata - service restored")
|
||||
}
|
||||
}
|
||||
|
||||
// createCaseInsensitiveStringMap creates a map with lowercase keys for case-insensitive matching.
|
||||
// This is used for case-insensitive matching of email addresses.
|
||||
// Parameters:
|
||||
|
||||
@@ -18,15 +18,15 @@ import (
|
||||
// TestExchangeCodeForToken_Comprehensive tests the ExchangeCodeForToken function comprehensively
|
||||
func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
tests := []struct {
|
||||
setupMock func(*httptest.Server) *TraefikOidc
|
||||
validateFunc func(*testing.T, *TokenResponse, error)
|
||||
name string
|
||||
grantType string
|
||||
code string
|
||||
redirectURL string
|
||||
codeVerifier string
|
||||
setupMock func(*httptest.Server) *TraefikOidc
|
||||
validateFunc func(*testing.T, *TokenResponse, error)
|
||||
wantErr bool
|
||||
expectedError string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful authorization code exchange",
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
func TestGoroutineLeakPrevention_ContextCancellation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cancelAfter time.Duration
|
||||
expectedLeaks int // Maximum expected goroutines after cleanup
|
||||
description string
|
||||
cancelAfter time.Duration
|
||||
expectedLeaks int
|
||||
}{
|
||||
{
|
||||
name: "immediate_cancellation",
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
// TestInitializeMetadata tests the initializeMetadata function
|
||||
func TestInitializeMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerURL string
|
||||
setupMock func() *httptest.Server
|
||||
validateFunc func(*testing.T, *TraefikOidc)
|
||||
name string
|
||||
providerURL string
|
||||
wantPanic bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
// TestGetNewTokenWithRefreshToken tests the GetNewTokenWithRefreshToken function
|
||||
func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
setupMock func(*httptest.Server) *TraefikOidc
|
||||
validateFunc func(*testing.T, *TokenResponse, error)
|
||||
wantErr bool
|
||||
name string
|
||||
refreshToken string
|
||||
expectedError string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful token refresh",
|
||||
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
// TestServeHTTP_ExcludedURLs tests the excluded URLs functionality
|
||||
func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
tests := []struct {
|
||||
excludedURLs map[string]struct{}
|
||||
name string
|
||||
path string
|
||||
excludedURLs map[string]struct{}
|
||||
shouldBypass bool
|
||||
}{
|
||||
{
|
||||
@@ -506,12 +506,12 @@ type MockSessionData struct {
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
redirectCount int
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
|
||||
+2
-2
@@ -81,11 +81,11 @@ func TestIsTestMode_DefaultBehavior(t *testing.T) {
|
||||
// TestVerifyAudience tests the verifyAudience function
|
||||
func TestVerifyAudience(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenAudience interface{}
|
||||
name string
|
||||
expectedAudience string
|
||||
expectError bool
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Audience matches",
|
||||
|
||||
+274
-7
@@ -192,9 +192,9 @@ func (ts *TestSuite) Setup() {
|
||||
|
||||
// MockJWKCache implements JWKCacheInterface
|
||||
type MockJWKCache struct {
|
||||
mu sync.RWMutex
|
||||
JWKS *JWKSet
|
||||
Err error
|
||||
JWKS *JWKSet
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Close is a no-op for the mock.
|
||||
@@ -209,11 +209,8 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup implementation
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
|
||||
// Real cleanup is for expired entries, not resetting all data
|
||||
}
|
||||
|
||||
// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls.
|
||||
@@ -2427,6 +2424,276 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances
|
||||
// with different provider URLs (e.g., different Keycloak realms) get separate
|
||||
// metadata refresh tasks. This addresses the issue reported in PR #88.
|
||||
func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Create two mock provider metadata servers simulating different Keycloak realms
|
||||
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm1",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer realm1Server.Close()
|
||||
|
||||
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://keycloak.example.com/realms/realm2",
|
||||
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
|
||||
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
|
||||
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
|
||||
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
}))
|
||||
defer realm2Server.Close()
|
||||
|
||||
// Config for realm1
|
||||
config1 := &Config{
|
||||
ProviderURL: realm1Server.URL,
|
||||
ClientID: "realm1-client",
|
||||
ClientSecret: "realm1-secret",
|
||||
CallbackURL: "/realm1/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm1_",
|
||||
}
|
||||
|
||||
// Config for realm2
|
||||
config2 := &Config{
|
||||
ProviderURL: realm2Server.URL,
|
||||
ClientID: "realm2-client",
|
||||
ClientSecret: "realm2-secret",
|
||||
CallbackURL: "/realm2/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
CookiePrefix: "_oidc_realm2_",
|
||||
}
|
||||
|
||||
// Create middleware instances for both realms
|
||||
middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config1, "realm1-middleware")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware for realm1: %v", err)
|
||||
}
|
||||
|
||||
middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config2, "realm2-middleware")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware for realm2: %v", err)
|
||||
}
|
||||
|
||||
m1, ok1 := middleware1.(*TraefikOidc)
|
||||
m2, ok2 := middleware2.(*TraefikOidc)
|
||||
if !ok1 || !ok2 {
|
||||
t.Fatalf("Middleware is not of type *TraefikOidc")
|
||||
}
|
||||
|
||||
// Clean up middleware instances
|
||||
defer func() {
|
||||
if err := m1.Close(); err != nil {
|
||||
t.Errorf("Failed to close realm1 middleware: %v", err)
|
||||
}
|
||||
if err := m2.Close(); err != nil {
|
||||
t.Errorf("Failed to close realm2 middleware: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for both instances to initialize
|
||||
select {
|
||||
case <-m1.initComplete:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Realm1 middleware failed to initialize")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-m2.initComplete:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Realm2 middleware failed to initialize")
|
||||
}
|
||||
|
||||
// Verify each instance has the correct issuer URL from their respective realms
|
||||
if !strings.Contains(m1.issuerURL, "realm1") {
|
||||
t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL)
|
||||
}
|
||||
if !strings.Contains(m2.issuerURL, "realm2") {
|
||||
t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL)
|
||||
}
|
||||
|
||||
// Verify provider URLs are different
|
||||
if m1.providerURL == m2.providerURL {
|
||||
t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL)
|
||||
}
|
||||
|
||||
// Test that each middleware can handle requests independently
|
||||
req1 := httptest.NewRequest("GET", "/realm1/protected", nil)
|
||||
rr1 := httptest.NewRecorder()
|
||||
m1.ServeHTTP(rr1, req1)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/realm2/protected", nil)
|
||||
rr2 := httptest.NewRecorder()
|
||||
m2.ServeHTTP(rr2, req2)
|
||||
|
||||
// Both should redirect to their respective auth URLs
|
||||
if rr1.Code != http.StatusFound {
|
||||
t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code)
|
||||
}
|
||||
if rr2.Code != http.StatusFound {
|
||||
t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code)
|
||||
}
|
||||
|
||||
location1 := rr1.Header().Get("Location")
|
||||
location2 := rr2.Header().Get("Location")
|
||||
|
||||
if !strings.Contains(location1, "realm1") {
|
||||
t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1)
|
||||
}
|
||||
if !strings.Contains(location2, "realm2") {
|
||||
t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically
|
||||
// recovers when the OIDC provider becomes available after initial failure.
|
||||
func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping test in short mode")
|
||||
}
|
||||
|
||||
// Track whether the provider is "available"
|
||||
providerAvailable := false
|
||||
var mu sync.Mutex
|
||||
|
||||
// Create mock provider that initially fails, then becomes available
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
available := providerAvailable
|
||||
mu.Unlock()
|
||||
|
||||
if !available {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/.well-known/openid-configuration" {
|
||||
metadata := ProviderMetadata{
|
||||
Issuer: "https://test-issuer.com",
|
||||
AuthURL: "https://test-issuer.com/auth",
|
||||
TokenURL: "https://test-issuer.com/token",
|
||||
JWKSURL: "https://test-issuer.com/jwks",
|
||||
EndSessionURL: "https://test-issuer.com/logout",
|
||||
}
|
||||
json.NewEncoder(w).Encode(metadata)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
config := &Config{
|
||||
ProviderURL: mockServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
||||
}
|
||||
|
||||
// Create middleware while provider is unavailable
|
||||
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}), config, "test-recovery")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware: %v", err)
|
||||
}
|
||||
|
||||
m, ok := middleware.(*TraefikOidc)
|
||||
if !ok {
|
||||
t.Fatalf("Middleware is not of type *TraefikOidc")
|
||||
}
|
||||
defer m.Close()
|
||||
|
||||
// Wait for initial initialization to complete (it should fail)
|
||||
select {
|
||||
case <-m.initComplete:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("Initialization did not complete in time")
|
||||
}
|
||||
|
||||
// Verify initial state - should be in failed state (no issuerURL)
|
||||
m.metadataMu.RLock()
|
||||
initialIssuer := m.issuerURL
|
||||
m.metadataMu.RUnlock()
|
||||
|
||||
if initialIssuer != "" {
|
||||
t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer)
|
||||
}
|
||||
|
||||
// First request should get 503
|
||||
req1 := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr1 := httptest.NewRecorder()
|
||||
m.ServeHTTP(rr1, req1)
|
||||
|
||||
if rr1.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code)
|
||||
}
|
||||
|
||||
// Now make the provider available
|
||||
mu.Lock()
|
||||
providerAvailable = true
|
||||
mu.Unlock()
|
||||
|
||||
// Reset the retry timer to allow immediate retry
|
||||
m.metadataRetryMutex.Lock()
|
||||
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
|
||||
m.metadataRetryMutex.Unlock()
|
||||
|
||||
// Second request should trigger recovery attempt
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr2 := httptest.NewRecorder()
|
||||
m.ServeHTTP(rr2, req2)
|
||||
|
||||
// Give the async recovery a moment to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check if recovery happened
|
||||
m.metadataMu.RLock()
|
||||
recoveredIssuer := m.issuerURL
|
||||
m.metadataMu.RUnlock()
|
||||
|
||||
if recoveredIssuer == "" {
|
||||
t.Error("Expected issuerURL to be recovered after provider became available")
|
||||
}
|
||||
|
||||
// Third request should succeed (redirect to auth, not 503)
|
||||
req3 := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr3 := httptest.NewRecorder()
|
||||
m.ServeHTTP(rr3, req3)
|
||||
|
||||
if rr3.Code == http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected redirect after recovery, still got 503")
|
||||
}
|
||||
|
||||
t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d",
|
||||
initialIssuer, recoveredIssuer, rr3.Code)
|
||||
}
|
||||
|
||||
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
+11
-11
@@ -42,27 +42,27 @@ func NewMemoryLeakFixesTestSuite() *MemoryLeakFixesTestSuite {
|
||||
|
||||
// MemoryTestCase defines a memory leak test scenario
|
||||
type MemoryTestCase struct {
|
||||
name string
|
||||
component string // "cache", "session", "token", "plugin", "pool"
|
||||
scenario string // "concurrent", "longrunning", "stress", "lifecycle"
|
||||
iterations int
|
||||
concurrency int
|
||||
setup func(*MemoryTestFramework) error
|
||||
execute func(*MemoryTestFramework) error
|
||||
validateLeak func(*testing.T, runtime.MemStats, runtime.MemStats)
|
||||
cleanup func(*MemoryTestFramework) error
|
||||
name string
|
||||
component string
|
||||
scenario string
|
||||
iterations int
|
||||
concurrency int
|
||||
}
|
||||
|
||||
// MemoryTestFramework provides common test infrastructure for memory tests
|
||||
type MemoryTestFramework struct {
|
||||
t *testing.T
|
||||
cache CacheInterface
|
||||
ctx context.Context
|
||||
t *testing.T
|
||||
plugin *TraefikOidc
|
||||
logger *Logger
|
||||
cancel context.CancelFunc
|
||||
servers []*httptest.Server
|
||||
configs []*Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewMemoryTestFramework creates a new test framework instance
|
||||
@@ -97,12 +97,12 @@ func (tf *MemoryTestFramework) Cleanup() {
|
||||
// ConsolidatedMemorySnapshot captures memory statistics at a point in time
|
||||
type ConsolidatedMemorySnapshot struct {
|
||||
Timestamp time.Time
|
||||
Description string
|
||||
Alloc uint64
|
||||
TotalAlloc uint64
|
||||
Sys uint64
|
||||
NumGC uint32
|
||||
Goroutines int
|
||||
Description string
|
||||
NumGC uint32
|
||||
}
|
||||
|
||||
// VerifyNoGoroutineLeaks checks for goroutine leaks
|
||||
@@ -1601,8 +1601,8 @@ func TestMemoryLeakConsolidated(t *testing.T) {
|
||||
|
||||
func TestGoroutineLeaks(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
test func(t *testing.T)
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "cache_no_leak",
|
||||
|
||||
+28
-38
@@ -10,30 +10,24 @@ import (
|
||||
|
||||
// MemoryStats holds comprehensive memory statistics
|
||||
type MemoryStats struct {
|
||||
// Go runtime memory stats
|
||||
HeapAllocBytes uint64 // bytes allocated and still in use
|
||||
HeapSysBytes uint64 // bytes obtained from system
|
||||
HeapIdleBytes uint64 // bytes in idle (unused) spans
|
||||
HeapInuseBytes uint64 // bytes in in-use spans
|
||||
HeapReleasedBytes uint64 // bytes released to the OS
|
||||
HeapObjects uint64 // total number of allocated objects
|
||||
StackInuseBytes uint64 // bytes in stack spans
|
||||
StackSysBytes uint64 // bytes obtained from system for stack
|
||||
GCSysBytes uint64 // bytes used for garbage collection system metadata
|
||||
NumGoroutines int // number of goroutines that currently exist
|
||||
LastGCTime time.Time // time of last garbage collection
|
||||
|
||||
// Application-specific memory tracking
|
||||
SessionCount int // current number of sessions
|
||||
TaskCount int // current number of background tasks
|
||||
CacheSize int64 // estimated cache memory usage
|
||||
ConnectionPools int // number of HTTP connection pools
|
||||
|
||||
// Memory pressure indicators
|
||||
MemoryPressure MemoryPressureLevel // overall memory pressure level
|
||||
GCFrequency float64 // garbage collections per minute
|
||||
|
||||
Timestamp time.Time
|
||||
LastGCTime time.Time
|
||||
Timestamp time.Time
|
||||
GCSysBytes uint64
|
||||
NumGoroutines int
|
||||
HeapReleasedBytes uint64
|
||||
HeapObjects uint64
|
||||
StackInuseBytes uint64
|
||||
StackSysBytes uint64
|
||||
HeapAllocBytes uint64
|
||||
HeapInuseBytes uint64
|
||||
HeapIdleBytes uint64
|
||||
SessionCount int
|
||||
TaskCount int
|
||||
CacheSize int64
|
||||
ConnectionPools int
|
||||
MemoryPressure MemoryPressureLevel
|
||||
GCFrequency float64
|
||||
HeapSysBytes uint64
|
||||
}
|
||||
|
||||
// MemoryPressureLevel indicates the current memory pressure
|
||||
@@ -66,22 +60,18 @@ func (mpl MemoryPressureLevel) String() string {
|
||||
|
||||
// MemoryMonitor provides comprehensive memory monitoring and alerting
|
||||
type MemoryMonitor struct {
|
||||
logger *Logger
|
||||
mu sync.RWMutex
|
||||
lastStats *MemoryStats
|
||||
lastGCCount uint32
|
||||
lastGCTime time.Time
|
||||
startTime time.Time
|
||||
alertThresholds MemoryAlertThresholds
|
||||
|
||||
// Memory leak detection
|
||||
baselineHeap uint64
|
||||
heapGrowthRate float64 // bytes per second
|
||||
suspiciousGrowth bool
|
||||
|
||||
// Goroutine tracking
|
||||
lastGCTime time.Time
|
||||
startTime time.Time
|
||||
lastStats *MemoryStats
|
||||
logger *Logger
|
||||
alertThresholds MemoryAlertThresholds
|
||||
baselineGoroutines int
|
||||
baselineHeap uint64
|
||||
heapGrowthRate float64
|
||||
maxGoroutines int64
|
||||
mu sync.RWMutex
|
||||
lastGCCount uint32
|
||||
suspiciousGrowth bool
|
||||
goroutineLeakAlert bool
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if issuerURL == "" {
|
||||
// Provider metadata initialization failed - try to recover
|
||||
// Retry every 30 seconds to allow automatic recovery when provider comes back online
|
||||
t.metadataRetryMutex.Lock()
|
||||
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
|
||||
if shouldRetry {
|
||||
t.lastMetadataRetryTime = time.Now()
|
||||
}
|
||||
t.metadataRetryMutex.Unlock()
|
||||
|
||||
if shouldRetry && t.providerURL != "" {
|
||||
t.logger.Info("Attempting to recover OIDC provider metadata...")
|
||||
go t.attemptMetadataRecovery()
|
||||
}
|
||||
|
||||
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
|
||||
t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
|
||||
return
|
||||
|
||||
+33
-52
@@ -19,31 +19,25 @@ import (
|
||||
|
||||
// MockOAuthProvider simulates an OAuth/OIDC provider for testing
|
||||
type MockOAuthProvider struct {
|
||||
TokenEndpoint string
|
||||
AuthEndpoint string
|
||||
JWKSEndpoint string
|
||||
RevokeEndpoint string
|
||||
EndSessionEndpoint string
|
||||
|
||||
// Configurable behaviors
|
||||
TokenExchangeFunc func(grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
JWKSResponseFunc func() ([]byte, error)
|
||||
|
||||
// Simulation flags
|
||||
SimulateTimeout bool
|
||||
SimulateRateLimit bool
|
||||
SimulateServerError bool
|
||||
TokenExchangeFunc func(grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
LastRequest *http.Request
|
||||
JWKSResponseFunc func() ([]byte, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
EndSessionEndpoint string
|
||||
TokenEndpoint string
|
||||
RevokeEndpoint string
|
||||
JWKSEndpoint string
|
||||
AuthEndpoint string
|
||||
RequestHistory []*http.Request
|
||||
LastRequestBody []byte
|
||||
TimeoutDuration time.Duration
|
||||
ResponseDelay time.Duration
|
||||
|
||||
// Request tracking
|
||||
RequestCount int32
|
||||
LastRequest *http.Request
|
||||
LastRequestBody []byte
|
||||
RequestHistory []*http.Request
|
||||
mu sync.Mutex
|
||||
mu sync.Mutex
|
||||
RequestCount int32
|
||||
SimulateServerError bool
|
||||
SimulateRateLimit bool
|
||||
SimulateTimeout bool
|
||||
}
|
||||
|
||||
// NewMockOAuthProvider creates a new mock OAuth provider with default endpoints
|
||||
@@ -236,22 +230,16 @@ func (m *MockOAuthProvider) Reset() {
|
||||
|
||||
// MockSessionManager implements a mock session manager for testing
|
||||
type MockSessionManager struct {
|
||||
Sessions map[string]*SessionData
|
||||
mu sync.RWMutex
|
||||
|
||||
// Configurable behaviors
|
||||
Sessions map[string]*SessionData
|
||||
GetSessionFunc func(r *http.Request) (*SessionData, error)
|
||||
SaveSessionFunc func(r *http.Request, w http.ResponseWriter, session *SessionData) error
|
||||
DeleteSessionFunc func(r *http.Request, w http.ResponseWriter) error
|
||||
|
||||
// Simulation flags
|
||||
SimulateError bool
|
||||
SimulateNotFound bool
|
||||
|
||||
// Tracking
|
||||
GetCallCount int32
|
||||
SaveCallCount int32
|
||||
DeleteCallCount int32
|
||||
mu sync.RWMutex
|
||||
GetCallCount int32
|
||||
SaveCallCount int32
|
||||
DeleteCallCount int32
|
||||
SimulateError bool
|
||||
SimulateNotFound bool
|
||||
}
|
||||
|
||||
// NewMockSessionManager creates a new mock session manager
|
||||
@@ -370,23 +358,16 @@ func (m *MockSessionManager) Reset() {
|
||||
|
||||
// MockHTTPClient implements a mock HTTP client for testing
|
||||
type MockHTTPClient struct {
|
||||
// Response configuration
|
||||
ResponseFunc func(req *http.Request) (*http.Response, error)
|
||||
|
||||
// Default response settings
|
||||
DefaultStatusCode int
|
||||
DefaultBody string
|
||||
ResponseFunc func(req *http.Request) (*http.Response, error)
|
||||
DefaultHeaders map[string]string
|
||||
|
||||
// Simulation flags
|
||||
SimulateTimeout bool
|
||||
SimulateError bool
|
||||
TimeoutDuration time.Duration
|
||||
|
||||
// Request tracking
|
||||
Requests []*http.Request
|
||||
RequestBodies [][]byte
|
||||
mu sync.Mutex
|
||||
DefaultBody string
|
||||
Requests []*http.Request
|
||||
RequestBodies [][]byte
|
||||
DefaultStatusCode int
|
||||
TimeoutDuration time.Duration
|
||||
mu sync.Mutex
|
||||
SimulateTimeout bool
|
||||
SimulateError bool
|
||||
}
|
||||
|
||||
// NewMockHTTPClient creates a new mock HTTP client
|
||||
|
||||
+25
-56
@@ -15,38 +15,19 @@ import (
|
||||
// It implements request coalescing, rate limiting, and circuit breaking
|
||||
// specifically for token refresh operations.
|
||||
type RefreshCoordinator struct {
|
||||
// inFlightRefreshes tracks active refresh operations by refresh token hash
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
// refreshMutex protects the inFlightRefreshes map
|
||||
refreshMutex sync.RWMutex
|
||||
|
||||
// sessionRefreshAttempts tracks refresh attempts per session
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
// attemptsMutex protects sessionRefreshAttempts map
|
||||
attemptsMutex sync.RWMutex
|
||||
|
||||
// Circuit breaker for refresh operations
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
|
||||
// Configuration
|
||||
config RefreshCoordinatorConfig
|
||||
|
||||
// Metrics
|
||||
metrics *RefreshMetrics
|
||||
|
||||
// Logger
|
||||
logger *Logger
|
||||
|
||||
// Cleanup goroutine control
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
|
||||
// delayedCleanupQueue stores items to be cleaned up after delay
|
||||
// Uses a timer-based approach instead of spawning goroutines per cleanup
|
||||
delayedCleanupQueue chan delayedCleanupItem
|
||||
// cleanupTimerPool reuses timers to avoid goroutine-per-cleanup
|
||||
cleanupTimerMu sync.Mutex
|
||||
cleanupTimers map[string]*time.Timer
|
||||
delayedCleanupQueue chan delayedCleanupItem
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
config RefreshCoordinatorConfig
|
||||
wg sync.WaitGroup
|
||||
attemptsMutex sync.RWMutex
|
||||
refreshMutex sync.RWMutex
|
||||
cleanupTimerMu sync.Mutex
|
||||
}
|
||||
|
||||
// RefreshCoordinatorConfig configures the refresh coordinator behavior
|
||||
@@ -89,18 +70,12 @@ func DefaultRefreshCoordinatorConfig() RefreshCoordinatorConfig {
|
||||
|
||||
// refreshOperation represents an in-flight refresh operation
|
||||
type refreshOperation struct {
|
||||
// refreshToken being refreshed (for validation)
|
||||
startTime time.Time
|
||||
result *refreshResult
|
||||
done chan struct{}
|
||||
refreshToken string
|
||||
// result stores the final result
|
||||
result *refreshResult
|
||||
// done signals when the operation is complete
|
||||
done chan struct{}
|
||||
// startTime tracks when the operation started
|
||||
startTime time.Time
|
||||
// waiterCount tracks number of goroutines waiting
|
||||
waiterCount int32
|
||||
// mutex protects the result field
|
||||
mutex sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
waiterCount int32
|
||||
}
|
||||
|
||||
// refreshResult contains the result of a refresh operation
|
||||
@@ -112,18 +87,12 @@ type refreshResult struct {
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session
|
||||
type refreshAttemptTracker struct {
|
||||
// attempts counts refresh attempts in current window
|
||||
attempts int32
|
||||
// lastAttemptTime is the timestamp of the last attempt
|
||||
lastAttemptTime time.Time
|
||||
// windowStartTime is when the current tracking window started
|
||||
windowStartTime time.Time
|
||||
// inCooldown indicates if this session is in cooldown
|
||||
inCooldown bool
|
||||
// cooldownEndTime is when cooldown period ends
|
||||
cooldownEndTime time.Time
|
||||
// consecutiveFailures tracks consecutive refresh failures
|
||||
lastAttemptTime time.Time
|
||||
windowStartTime time.Time
|
||||
cooldownEndTime time.Time
|
||||
attempts int32
|
||||
consecutiveFailures int32
|
||||
inCooldown bool
|
||||
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics
|
||||
@@ -140,18 +109,18 @@ type RefreshMetrics struct {
|
||||
|
||||
// delayedCleanupItem represents an item scheduled for delayed cleanup
|
||||
type delayedCleanupItem struct {
|
||||
tokenHash string
|
||||
cleanupAt time.Time
|
||||
tokenHash string
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
type RefreshCircuitBreaker struct {
|
||||
state int32 // 0=closed, 1=open, 2=half-open
|
||||
failures int32
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
state int32
|
||||
failures int32
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
|
||||
|
||||
@@ -132,7 +132,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
session.SetEmail("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
session.SetRefreshToken("azure-refresh-token")
|
||||
|
||||
// Save with proper security
|
||||
@@ -178,9 +178,9 @@ func testIssue53SameSiteCookies(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
expectedSecure bool
|
||||
expectedSameSite http.SameSite
|
||||
description string
|
||||
expectedSameSite http.SameSite
|
||||
expectedSecure bool
|
||||
}{
|
||||
{
|
||||
name: "HTTPS via proxy",
|
||||
@@ -240,9 +240,9 @@ func testIssue60MissingClaimFields(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
description string
|
||||
headers []traefikoidc.TemplatedHeader
|
||||
shouldValidate bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Direct claim access",
|
||||
|
||||
+555
-23
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -66,23 +67,160 @@ func generateSecureRandomString(length int) (string, error) {
|
||||
// These are appended to the cookiePrefix to create full cookie names
|
||||
// #nosec G101 -- These are cookie name suffixes, not hardcoded credentials
|
||||
const (
|
||||
mainCookieSuffix = "m"
|
||||
accessTokenSuffix = "a"
|
||||
refreshTokenSuffix = "r"
|
||||
idTokenSuffix = "id"
|
||||
defaultCookiePrefix = "_oidc_raczylo_"
|
||||
mainCookieSuffix = "m"
|
||||
accessTokenSuffix = "a"
|
||||
refreshTokenSuffix = "r"
|
||||
idTokenSuffix = "id"
|
||||
combinedCookieSuffix = "s" // Combined session cookie suffix
|
||||
defaultCookiePrefix = "_oidc_raczylo_"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBrowserCookieSize = 3500
|
||||
// maxCookieSize is the maximum raw data size per chunk before securecookie encoding.
|
||||
//
|
||||
// Browser cookie limit: 4096 bytes
|
||||
// Securecookie overhead: ~2x (gob + AES encryption + MAC + double base64)
|
||||
// Safety headroom: 1000 bytes for cookie metadata, headers, edge cases
|
||||
//
|
||||
// Math: 1400 raw → ~2800 encoded → safely under (4096 - 1000 = 3096) limit
|
||||
maxCookieSize = 1400
|
||||
|
||||
maxCookieSize = 1200
|
||||
// maxCombinedChunks is the maximum number of chunks allowed for combined session
|
||||
maxCombinedChunks = 10
|
||||
|
||||
absoluteSessionTimeout = 24 * time.Hour
|
||||
|
||||
minEncryptionKeyLength = 32
|
||||
)
|
||||
|
||||
// combinedSessionPayload is the JSON structure for combined cookie storage.
|
||||
// Uses short field names to minimize size.
|
||||
type combinedSessionPayload struct {
|
||||
X map[string]interface{} `json:"x,omitempty"`
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
Ip string `json:"ip,omitempty"`
|
||||
Ca int64 `json:"ca,omitempty"`
|
||||
Rc int `json:"rc,omitempty"`
|
||||
Au bool `json:"au,omitempty"`
|
||||
}
|
||||
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
"incoming_path": true,
|
||||
"created_at": true,
|
||||
"redirect_count": true,
|
||||
}
|
||||
|
||||
// compressCombinedPayload compresses the combined session payload using gzip.
|
||||
// It serializes the payload to JSON, compresses it, and returns base64-encoded data.
|
||||
// Returns the compressed string and any error encountered.
|
||||
func compressCombinedPayload(payload *combinedSessionPayload) (string, error) {
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal combined payload: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
if _, err := gz.Write(jsonData); err != nil {
|
||||
return "", fmt.Errorf("failed to compress combined payload: %w", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
return "", fmt.Errorf("failed to close gzip writer: %w", err)
|
||||
}
|
||||
|
||||
compressed := base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
return compressed, nil
|
||||
}
|
||||
|
||||
// decompressCombinedPayload decompresses a base64+gzip encoded combined session payload.
|
||||
// Returns the deserialized payload and any error encountered.
|
||||
func decompressCombinedPayload(compressed string) (*combinedSessionPayload, error) {
|
||||
if compressed == "" {
|
||||
return nil, fmt.Errorf("empty compressed data")
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
gr, err := gzip.NewReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gr.Close()
|
||||
|
||||
// Limit decompressed size to prevent zip bombs
|
||||
limitedReader := io.LimitReader(gr, 512*1024) // 512KB max
|
||||
decompressed, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decompress: %w", err)
|
||||
}
|
||||
|
||||
var payload combinedSessionPayload
|
||||
if err := json.Unmarshal(decompressed, &payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal combined payload: %w", err)
|
||||
}
|
||||
|
||||
return &payload, nil
|
||||
}
|
||||
|
||||
// splitCombinedIntoChunks splits compressed data into chunks of maxCookieSize.
|
||||
// Returns the chunks and the total number of chunks.
|
||||
func splitCombinedIntoChunks(data string, chunkSize int) []string {
|
||||
if len(data) <= chunkSize {
|
||||
return []string{data}
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; i < len(data); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(data) {
|
||||
end = len(data)
|
||||
}
|
||||
chunks = append(chunks, data[i:end])
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// assembleCombinedChunks reassembles chunks back into the original compressed data.
|
||||
// It reads chunks from session values in order (chunk_0, chunk_1, etc.).
|
||||
func assembleCombinedChunks(sessions []*sessions.Session) string {
|
||||
if len(sessions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
for i := 0; i < len(sessions); i++ {
|
||||
session := sessions[i]
|
||||
if session == nil {
|
||||
break
|
||||
}
|
||||
chunk, ok := session.Values["d"].(string) // "d" for data
|
||||
if !ok || chunk == "" {
|
||||
break
|
||||
}
|
||||
parts = append(parts, chunk)
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
// compressToken compresses a JWT token using gzip compression if beneficial.
|
||||
// It validates the token format, attempts compression, and verifies the compressed
|
||||
// data can be decompressed correctly. Only compresses if it reduces size.
|
||||
@@ -236,22 +374,22 @@ func decompressTokenInternal(compressed string) string {
|
||||
// session object reuse and supports both HTTP and HTTPS schemes.
|
||||
type SessionManager struct {
|
||||
sessionPool sync.Pool
|
||||
ctx context.Context
|
||||
store sessions.Store
|
||||
logger *Logger
|
||||
chunkManager *ChunkManager
|
||||
cookieDomain string
|
||||
cookiePrefix string // Prefix for cookie names (default: "_oidc_raczylo_")
|
||||
sessionMaxAge time.Duration // Maximum session age (default: 24 hours)
|
||||
cleanupMutex sync.RWMutex
|
||||
forceHTTPS bool
|
||||
cleanupDone bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
memoryMonitor *TaskMemoryMonitor
|
||||
cancel context.CancelFunc
|
||||
cookieDomain string
|
||||
cookiePrefix string
|
||||
sessionMaxAge time.Duration
|
||||
activeSessions int64
|
||||
poolHits int64
|
||||
poolMisses int64
|
||||
cleanupMutex sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
forceHTTPS bool
|
||||
cleanupDone bool
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new SessionManager instance with secure defaults.
|
||||
@@ -312,6 +450,8 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, cookieDomain strin
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
combinedChunks: make(map[int]*sessions.Session),
|
||||
useCombinedStorage: true, // Use combined storage by default for new sessions
|
||||
refreshMutex: sync.Mutex{},
|
||||
sessionMutex: sync.RWMutex{},
|
||||
dirty: false,
|
||||
@@ -349,6 +489,17 @@ func (sm *SessionManager) idTokenCookieName() string {
|
||||
return sm.cookiePrefix + idTokenSuffix
|
||||
}
|
||||
|
||||
// combinedCookieName returns the combined session cookie base name with the configured prefix
|
||||
// Chunk cookies are named: prefix + "s" + "_" + chunkIndex (e.g., "_oidc_raczylo_s_0")
|
||||
func (sm *SessionManager) combinedCookieName() string {
|
||||
return sm.cookiePrefix + combinedCookieSuffix
|
||||
}
|
||||
|
||||
// combinedChunkCookieName returns the name for a specific combined session chunk
|
||||
func (sm *SessionManager) combinedChunkCookieName(chunkIndex int) string {
|
||||
return fmt.Sprintf("%s_%d", sm.combinedCookieName(), chunkIndex)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the SessionManager and all its background tasks
|
||||
func (sm *SessionManager) Shutdown() error {
|
||||
var shutdownErr error
|
||||
@@ -650,7 +801,7 @@ func (sm *SessionManager) GetSessionMetrics() map[string]interface{} {
|
||||
metrics["force_https"] = sm.forceHTTPS
|
||||
metrics["absolute_timeout_hours"] = sm.sessionMaxAge.Hours()
|
||||
metrics["max_cookie_size"] = maxCookieSize
|
||||
metrics["max_browser_cookie_size"] = maxBrowserCookieSize
|
||||
metrics["max_encoded_cookie_size"] = maxCookieSize * 2 // ~2x encoding overhead
|
||||
|
||||
if cookieStore, ok := sm.store.(*sessions.CookieStore); ok && len(cookieStore.Codecs) > 0 {
|
||||
metrics["has_encryption"] = true
|
||||
@@ -824,9 +975,10 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
// GetSession retrieves or creates session data from the HTTP request.
|
||||
// It loads the main session and all token chunk sessions, performing validation
|
||||
// and timeout checks. The returned session must be explicitly returned to the pool
|
||||
// by calling returnToPoolSafely() to prevent memory leaks.
|
||||
// It first tries to load from combined cookies (new format), falling back to legacy
|
||||
// cookies if combined cookies don't exist. Performs validation and timeout checks.
|
||||
// The returned session must be explicitly returned to the pool by calling
|
||||
// returnToPoolSafely() to prevent memory leaks.
|
||||
// MEMORY LEAK FIX: Session is NOT returned to pool here - caller must call ReturnToPool() when done.
|
||||
// Parameters:
|
||||
// - r: The HTTP request containing session cookies.
|
||||
@@ -853,6 +1005,26 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
return nil, fmt.Errorf("%s: %w", message, err)
|
||||
}
|
||||
|
||||
// Try to load from combined cookies first (new format)
|
||||
if sm.loadFromCombinedCookies(r, sessionData) {
|
||||
sessionData.useCombinedStorage = true
|
||||
sm.logger.Debug("Loaded session from combined cookies")
|
||||
|
||||
// Check session timeout
|
||||
if sessionData.getCreatedAtUnsafe() > 0 {
|
||||
if time.Since(time.Unix(sessionData.getCreatedAtUnsafe(), 0)) > sm.sessionMaxAge {
|
||||
_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
|
||||
return handleError(fmt.Errorf("session timeout"), "session expired")
|
||||
}
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// Fall back to legacy cookies
|
||||
sessionData.useCombinedStorage = false
|
||||
sm.logger.Debug("Loading session from legacy cookies")
|
||||
|
||||
var err error
|
||||
sessionData.mainSession, err = sm.store.Get(r, sm.mainCookieName())
|
||||
if err != nil {
|
||||
@@ -895,9 +1067,94 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
sm.getTokenChunkSessions(r, sm.refreshTokenCookieName(), sessionData.refreshTokenChunks)
|
||||
sm.getTokenChunkSessions(r, sm.idTokenCookieName(), sessionData.idTokenChunks)
|
||||
|
||||
// If legacy session has data, migrate to combined storage on next save
|
||||
if !sessionData.mainSession.IsNew {
|
||||
sessionData.useCombinedStorage = true
|
||||
sessionData.dirty = true // Mark dirty to trigger migration on save
|
||||
sm.logger.Debug("Legacy session found, will migrate to combined storage on save")
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// loadFromCombinedCookies attempts to load session data from combined cookies.
|
||||
// Returns true if combined cookies were found and successfully loaded.
|
||||
func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *SessionData) bool {
|
||||
// Check if first combined chunk exists
|
||||
firstChunk, err := sm.store.Get(r, sm.combinedChunkCookieName(0))
|
||||
if err != nil || firstChunk.IsNew {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get total chunk count from first chunk
|
||||
totalChunks, ok := firstChunk.Values["n"].(int)
|
||||
if !ok || totalChunks < 1 || totalChunks > maxCombinedChunks {
|
||||
sm.logger.Debugf("Invalid combined cookie chunk count: %v", firstChunk.Values["n"])
|
||||
return false
|
||||
}
|
||||
|
||||
// Load all chunks
|
||||
chunkSessions := make([]*sessions.Session, totalChunks)
|
||||
chunkSessions[0] = firstChunk
|
||||
sessionData.combinedChunks[0] = firstChunk
|
||||
|
||||
for i := 1; i < totalChunks; i++ {
|
||||
chunk, err := sm.store.Get(r, sm.combinedChunkCookieName(i))
|
||||
if err != nil || chunk.IsNew {
|
||||
sm.logger.Debugf("Missing combined cookie chunk %d", i)
|
||||
return false
|
||||
}
|
||||
chunkSessions[i] = chunk
|
||||
sessionData.combinedChunks[i] = chunk
|
||||
}
|
||||
|
||||
// Assemble and decompress
|
||||
compressed := assembleCombinedChunks(chunkSessions)
|
||||
if compressed == "" {
|
||||
sm.logger.Debug("Failed to assemble combined chunks")
|
||||
return false
|
||||
}
|
||||
|
||||
payload, err := decompressCombinedPayload(compressed)
|
||||
if err != nil {
|
||||
sm.logger.Debugf("Failed to decompress combined payload: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Hydrate the legacy session objects for compatibility with existing getter methods
|
||||
// We need to initialize them even though we're using combined storage
|
||||
sessionData.mainSession, _ = sm.store.Get(r, sm.mainCookieName())
|
||||
sessionData.accessSession, _ = sm.store.Get(r, sm.accessTokenCookieName())
|
||||
sessionData.refreshSession, _ = sm.store.Get(r, sm.refreshTokenCookieName())
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
sessionData.mainSession.Values["code_verifier"] = payload.Cv
|
||||
sessionData.mainSession.Values["incoming_path"] = payload.Ip
|
||||
sessionData.mainSession.Values["created_at"] = payload.Ca
|
||||
sessionData.mainSession.Values["redirect_count"] = payload.Rc
|
||||
|
||||
// Restore extra custom session values
|
||||
for key, val := range payload.X {
|
||||
sessionData.mainSession.Values[key] = val
|
||||
}
|
||||
|
||||
sessionData.accessSession.Values["token"] = payload.A
|
||||
sessionData.accessSession.Values["compressed"] = false
|
||||
|
||||
sessionData.refreshSession.Values["token"] = payload.R
|
||||
sessionData.refreshSession.Values["compressed"] = false
|
||||
|
||||
sessionData.idTokenSession.Values["token"] = payload.I
|
||||
sessionData.idTokenSession.Values["compressed"] = false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getTokenChunkSessions loads all available token chunk sessions for a given token type.
|
||||
// It iterates through numbered chunk sessions until no more are found,
|
||||
// populating the provided chunks map with the loaded sessions.
|
||||
@@ -920,11 +1177,13 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
||||
// SessionData represents a user's authentication session with comprehensive token management.
|
||||
// It handles main session data and supports large tokens that need to be
|
||||
// split across multiple cookies due to browser size limitations.
|
||||
// Supports both legacy (separate cookies) and combined (single compressed cookie) storage.
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
|
||||
request *http.Request
|
||||
|
||||
// Legacy storage (kept for backward compatibility during migration)
|
||||
mainSession *sessions.Session
|
||||
|
||||
accessSession *sessions.Session
|
||||
@@ -939,6 +1198,12 @@ type SessionData struct {
|
||||
|
||||
idTokenChunks map[int]*sessions.Session
|
||||
|
||||
// Combined storage (new approach - single compressed cookie)
|
||||
combinedChunks map[int]*sessions.Session
|
||||
|
||||
// useCombinedStorage indicates whether to use the new combined storage format
|
||||
useCombinedStorage bool
|
||||
|
||||
refreshMutex sync.Mutex
|
||||
|
||||
sessionMutex sync.RWMutex
|
||||
@@ -966,6 +1231,7 @@ func (sd *SessionData) MarkDirty() {
|
||||
// Save persists all session data including main session and token chunks.
|
||||
// It applies security options, saves all session components, and handles
|
||||
// errors gracefully by continuing to save other components even if one fails.
|
||||
// Uses combined cookie storage for efficiency when useCombinedStorage is true.
|
||||
// Parameters:
|
||||
// - r: The HTTP request context for security option configuration.
|
||||
// - w: The HTTP response writer for setting session cookies.
|
||||
@@ -978,6 +1244,117 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
options = sd.manager.EnhanceSessionSecurity(options, r)
|
||||
|
||||
// Use combined storage for new sessions
|
||||
if sd.useCombinedStorage {
|
||||
return sd.saveCombined(r, w, options)
|
||||
}
|
||||
|
||||
// Legacy storage path (for backward compatibility)
|
||||
return sd.saveLegacy(r, w, options)
|
||||
}
|
||||
|
||||
// saveCombined saves all session data in a single compressed, chunked cookie.
|
||||
// This reduces cookie count and total size through combined compression.
|
||||
func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
|
||||
// Build the combined payload
|
||||
payload := &combinedSessionPayload{
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
Cv: sd.getCodeVerifierUnsafe(),
|
||||
Ip: sd.getIncomingPathUnsafe(),
|
||||
Ca: sd.getCreatedAtUnsafe(),
|
||||
Rc: sd.getRedirectCountUnsafe(),
|
||||
}
|
||||
|
||||
// Collect extra session values not handled by the standard fields
|
||||
sd.sessionMutex.RLock()
|
||||
if sd.mainSession != nil && len(sd.mainSession.Values) > 0 {
|
||||
extra := make(map[string]interface{})
|
||||
for key, val := range sd.mainSession.Values {
|
||||
keyStr, ok := key.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Skip known session keys that are already in the payload
|
||||
if knownSessionKeys[keyStr] {
|
||||
continue
|
||||
}
|
||||
// Store the extra value (must be JSON-serializable)
|
||||
extra[keyStr] = val
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
payload.X = extra
|
||||
}
|
||||
}
|
||||
sd.sessionMutex.RUnlock()
|
||||
|
||||
// Compress the payload
|
||||
compressed, err := compressCombinedPayload(payload)
|
||||
if err != nil {
|
||||
sd.manager.logger.Errorf("Failed to compress combined payload: %v", err)
|
||||
// Fall back to legacy storage on compression failure
|
||||
return sd.saveLegacy(r, w, options)
|
||||
}
|
||||
|
||||
sd.manager.logger.Debugf("Combined session: raw payload compressed to %d bytes", len(compressed))
|
||||
|
||||
// Split into chunks
|
||||
chunks := splitCombinedIntoChunks(compressed, maxCookieSize)
|
||||
if len(chunks) > maxCombinedChunks {
|
||||
sd.manager.logger.Errorf("Combined session requires %d chunks, exceeds max %d", len(chunks), maxCombinedChunks)
|
||||
return fmt.Errorf("session data too large: requires %d chunks, max is %d", len(chunks), maxCombinedChunks)
|
||||
}
|
||||
|
||||
sd.manager.logger.Debugf("Combined session split into %d chunks", len(chunks))
|
||||
|
||||
var firstErr error
|
||||
|
||||
// Save each chunk
|
||||
for i, chunkData := range chunks {
|
||||
cookieName := sd.manager.combinedChunkCookieName(i)
|
||||
session, err := sd.manager.store.Get(r, cookieName)
|
||||
if err != nil {
|
||||
sd.manager.logger.Errorf("Failed to get combined chunk session %s: %v", cookieName, err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
session.Values["d"] = chunkData // "d" for data
|
||||
session.Values["n"] = len(chunks) // "n" for total number of chunks
|
||||
session.Values["i"] = i // "i" for index
|
||||
session.Options = options
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
sd.manager.logger.Errorf("Failed to save combined chunk %d: %v", i, err)
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
sd.combinedChunks[i] = session
|
||||
}
|
||||
|
||||
// Expire old combined chunks that are no longer needed
|
||||
sd.expireOldCombinedChunks(r, w, options, len(chunks))
|
||||
|
||||
// Expire legacy cookies if they exist (migration)
|
||||
sd.expireLegacyCookies(r, w, options)
|
||||
|
||||
if firstErr == nil {
|
||||
sd.dirty = false
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// saveLegacy saves session data using the old separate cookie approach.
|
||||
// Kept for backward compatibility during migration.
|
||||
func (sd *SessionData) saveLegacy(r *http.Request, w http.ResponseWriter, options *sessions.Options) error {
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
@@ -1002,11 +1379,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
}
|
||||
|
||||
saveOrLogError(sd.mainSession, "main")
|
||||
|
||||
saveOrLogError(sd.accessSession, "access token")
|
||||
|
||||
saveOrLogError(sd.refreshSession, "refresh token")
|
||||
|
||||
saveOrLogError(sd.idTokenSession, "ID token")
|
||||
|
||||
for i, sessionChunk := range sd.accessTokenChunks {
|
||||
@@ -1030,6 +1404,84 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// expireOldCombinedChunks expires combined cookie chunks that are no longer needed.
|
||||
func (sd *SessionData) expireOldCombinedChunks(r *http.Request, w http.ResponseWriter, options *sessions.Options, currentChunks int) {
|
||||
// Expire chunks beyond the current count
|
||||
for i := currentChunks; i < maxCombinedChunks; i++ {
|
||||
cookieName := sd.manager.combinedChunkCookieName(i)
|
||||
session, err := sd.manager.store.Get(r, cookieName)
|
||||
if err != nil || session.IsNew {
|
||||
// No more old chunks
|
||||
break
|
||||
}
|
||||
// Expire this chunk
|
||||
expireOptions := *options
|
||||
expireOptions.MaxAge = -1
|
||||
session.Options = &expireOptions
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
if err := session.Save(r, w); err != nil {
|
||||
sd.manager.logger.Debugf("Failed to expire old combined chunk %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expireLegacyCookies expires old legacy format cookies during migration.
|
||||
func (sd *SessionData) expireLegacyCookies(r *http.Request, w http.ResponseWriter, options *sessions.Options) {
|
||||
expireOptions := *options
|
||||
expireOptions.MaxAge = -1
|
||||
|
||||
// Helper to expire a legacy session cookie without clearing in-memory values
|
||||
// IMPORTANT: We must NOT clear values from sessions that sd is holding,
|
||||
// as store.Get() returns the same cached session object
|
||||
expireLegacyChunk := func(cookieName string) {
|
||||
session, err := sd.manager.store.Get(r, cookieName)
|
||||
if err != nil || session.IsNew {
|
||||
return // Cookie doesn't exist
|
||||
}
|
||||
session.Options = &expireOptions
|
||||
// Clear values from chunk cookies (not the main session objects)
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
_ = session.Save(r, w) // Best effort
|
||||
}
|
||||
|
||||
// For main session cookies, only set expiration WITHOUT clearing values
|
||||
// because sd.mainSession, sd.accessSession, etc. point to these same objects
|
||||
expireLegacyMain := func(cookieName string) {
|
||||
session, err := sd.manager.store.Get(r, cookieName)
|
||||
if err != nil || session.IsNew {
|
||||
return // Cookie doesn't exist
|
||||
}
|
||||
// Just expire the cookie, don't clear values (they're still needed in memory)
|
||||
session.Options = &expireOptions
|
||||
_ = session.Save(r, w) // Best effort
|
||||
}
|
||||
|
||||
// Expire main legacy cookies (don't clear in-memory values)
|
||||
expireLegacyMain(sd.manager.mainCookieName())
|
||||
expireLegacyMain(sd.manager.accessTokenCookieName())
|
||||
expireLegacyMain(sd.manager.refreshTokenCookieName())
|
||||
expireLegacyMain(sd.manager.idTokenCookieName())
|
||||
|
||||
// Expire legacy chunk cookies (safe to clear values, they're separate from main sessions)
|
||||
for i := 0; i < 50; i++ { // Max legacy chunks was 50
|
||||
accessChunk := fmt.Sprintf("%s_%d", sd.manager.accessTokenCookieName(), i)
|
||||
refreshChunk := fmt.Sprintf("%s_%d", sd.manager.refreshTokenCookieName(), i)
|
||||
idChunk := fmt.Sprintf("%s_%d", sd.manager.idTokenCookieName(), i)
|
||||
|
||||
session, err := sd.manager.store.Get(r, accessChunk)
|
||||
if err != nil || session.IsNew {
|
||||
break // No more chunks
|
||||
}
|
||||
expireLegacyChunk(accessChunk)
|
||||
expireLegacyChunk(refreshChunk)
|
||||
expireLegacyChunk(idChunk)
|
||||
}
|
||||
}
|
||||
|
||||
// clearSessionValues removes all values from a session and optionally expires it.
|
||||
// This is used during session cleanup and logout operations.
|
||||
// Parameters:
|
||||
@@ -1263,6 +1715,12 @@ func (sd *SessionData) Reset() {
|
||||
resetSession(sd.refreshSession)
|
||||
resetSession(sd.idTokenSession)
|
||||
|
||||
// Clear combined chunks
|
||||
for k, session := range sd.combinedChunks {
|
||||
resetSession(session)
|
||||
delete(sd.combinedChunks, k)
|
||||
}
|
||||
|
||||
// Clear redirect count to prevent leaking between sessions
|
||||
if sd.mainSession != nil && sd.mainSession.Values != nil {
|
||||
delete(sd.mainSession.Values, "redirect_count")
|
||||
@@ -1271,6 +1729,7 @@ func (sd *SessionData) Reset() {
|
||||
sd.dirty = false
|
||||
sd.inUse = false
|
||||
sd.request = nil
|
||||
sd.useCombinedStorage = true // Reset to use combined storage by default
|
||||
|
||||
// Reset the refresh mutex to ensure clean state
|
||||
// Note: We don't need to lock it since sessionMutex is already held
|
||||
@@ -1886,8 +2345,9 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
||||
// Returns:
|
||||
// - true if the chunk is safe to store, false if it may exceed browser limits.
|
||||
func validateChunkSize(chunkData string) bool {
|
||||
// Estimate ~50% overhead for encoding, compare against ~2x maxCookieSize limit
|
||||
estimatedEncodedSize := len(chunkData) + (len(chunkData) * 50 / 100)
|
||||
return estimatedEncodedSize <= maxBrowserCookieSize
|
||||
return estimatedEncodedSize <= maxCookieSize*2
|
||||
}
|
||||
|
||||
// isCorruptionMarker detects if data contains known corruption indicators.
|
||||
@@ -2088,6 +2548,78 @@ func (sd *SessionData) getIDTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getRefreshTokenUnsafe retrieves the refresh token without acquiring locks.
|
||||
// Used when the session mutex is already held to prevent deadlocks.
|
||||
func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
|
||||
|
||||
if sd.manager == nil || sd.manager.chunkManager == nil {
|
||||
return token
|
||||
}
|
||||
|
||||
result := sd.manager.chunkManager.GetToken(
|
||||
token,
|
||||
compressed,
|
||||
sd.refreshTokenChunks,
|
||||
RefreshTokenConfig,
|
||||
)
|
||||
|
||||
if result.Error != nil {
|
||||
// Handle opaque tokens
|
||||
if token != "" && !compressed && len(sd.refreshTokenChunks) == 0 {
|
||||
if strings.Count(token, ".") != 2 {
|
||||
return token
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
func (sd *SessionData) getCSRFUnsafe() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// getNonceUnsafe retrieves the nonce without acquiring locks.
|
||||
func (sd *SessionData) getNonceUnsafe() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// getCodeVerifierUnsafe retrieves the code verifier without acquiring locks.
|
||||
func (sd *SessionData) getCodeVerifierUnsafe() string {
|
||||
codeVerifier, _ := sd.mainSession.Values["code_verifier"].(string)
|
||||
return codeVerifier
|
||||
}
|
||||
|
||||
// getIncomingPathUnsafe retrieves the incoming path without acquiring locks.
|
||||
func (sd *SessionData) getIncomingPathUnsafe() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// getCreatedAtUnsafe retrieves the created_at timestamp without acquiring locks.
|
||||
func (sd *SessionData) getCreatedAtUnsafe() int64 {
|
||||
createdAt, _ := sd.mainSession.Values["created_at"].(int64)
|
||||
return createdAt
|
||||
}
|
||||
|
||||
// getRedirectCountUnsafe retrieves the redirect count without acquiring locks.
|
||||
func (sd *SessionData) getRedirectCountUnsafe() int {
|
||||
count, _ := sd.mainSession.Values["redirect_count"].(int)
|
||||
return count
|
||||
}
|
||||
|
||||
// SetIDToken stores an ID token with automatic compression and chunking.
|
||||
// It validates the JWT format, compresses if beneficial, and splits into chunks
|
||||
// if the token exceeds cookie size limits. Includes comprehensive validation.
|
||||
|
||||
@@ -100,13 +100,12 @@ type Logger interface {
|
||||
// and error handling to ensure data integrity and prevent security vulnerabilities
|
||||
// throughout the process.
|
||||
type ChunkManager struct {
|
||||
logger Logger
|
||||
mutex *sync.RWMutex
|
||||
// sessionMap provides bounded session storage to prevent memory leaks
|
||||
lastCleanup time.Time
|
||||
logger Logger
|
||||
mutex *sync.RWMutex
|
||||
sessionMap map[string]*SessionEntry
|
||||
maxSessions int
|
||||
sessionTTL time.Duration
|
||||
lastCleanup time.Time
|
||||
}
|
||||
|
||||
// NewChunkManager creates a new ChunkManager instance with proper initialization.
|
||||
@@ -361,8 +360,8 @@ func (cm *ChunkManager) StoreSession(key string, session *sessions.Session) {
|
||||
if shouldEvict {
|
||||
// Find oldest sessions to remove
|
||||
type sessionAge struct {
|
||||
key string
|
||||
lastUsed time.Time
|
||||
key string
|
||||
}
|
||||
|
||||
sessions := make([]sessionAge, 0, currentLocal)
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestTokenValidatorJWT(t *testing.T) {
|
||||
validator := NewTokenValidator()
|
||||
|
||||
// Test valid JWT format (using base64url encoded parts that are long enough)
|
||||
validJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
validJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" // trufflehog:ignore
|
||||
err := validator.ValidateJWTFormat(validJWT, "test")
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid JWT to pass, got error: %v", err)
|
||||
@@ -186,10 +186,10 @@ func TestTokenConfigValidation(t *testing.T) {
|
||||
func TestSessionMapBounds_HardLimitEnforcement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
maxSessions int
|
||||
sessionCount int
|
||||
expectEviction bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "within_limit",
|
||||
@@ -760,18 +760,18 @@ func TestValidateJWTContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectError bool
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT with required ID token claims",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6ImNsaWVudElkIiwiZXhwIjoxNjQ2MDY0MDAwLCJpYXQiOjE2NDYwNjA0MDB9.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6ImNsaWVudElkIiwiZXhwIjoxNjQ2MDY0MDAwLCJpYXQiOjE2NDYwNjA0MDB9.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", // trufflehog:ignore
|
||||
expectError: false,
|
||||
description: "JWT with all required ID token claims should pass",
|
||||
},
|
||||
{
|
||||
name: "JWT missing required claims",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", // trufflehog:ignore
|
||||
expectError: true,
|
||||
description: "JWT missing required claims should fail",
|
||||
},
|
||||
@@ -810,8 +810,8 @@ func TestValidateJWTHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
expectError bool
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT header",
|
||||
@@ -865,9 +865,9 @@ func TestValidateJWTPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload string
|
||||
description string
|
||||
config TokenConfig
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Valid ID token payload",
|
||||
@@ -927,8 +927,8 @@ func TestValidateJWTSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
signature string
|
||||
expectError bool
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid signature",
|
||||
@@ -976,9 +976,9 @@ func TestValidateChunkStructure(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
chunks []ChunkData
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Valid chunk structure",
|
||||
@@ -1055,11 +1055,11 @@ func TestValidateChunkData(t *testing.T) {
|
||||
config := AccessTokenConfig
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunk ChunkData
|
||||
name string
|
||||
description string
|
||||
expectedTotal int
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Valid chunk data",
|
||||
@@ -1218,13 +1218,13 @@ func TestGetToken(t *testing.T) {
|
||||
cm := NewChunkManager(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mainSession *sessions.Session
|
||||
chunks map[int]*sessions.Session
|
||||
config TokenConfig
|
||||
name string
|
||||
expectedToken string
|
||||
expectError bool
|
||||
description string
|
||||
config TokenConfig
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Token from main session",
|
||||
@@ -1363,8 +1363,8 @@ func TestSerializeTokenToChunks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectError bool
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid token serialization",
|
||||
@@ -1436,10 +1436,10 @@ func TestDeserializeTokenFromChunks(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []ChunkData
|
||||
expectedToken string
|
||||
expectError bool
|
||||
description string
|
||||
chunks []ChunkData
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid chunks deserialization",
|
||||
@@ -1522,10 +1522,10 @@ func TestEncodeDecodeChunk(t *testing.T) {
|
||||
cs := NewChunkSerializer(NewNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunk ChunkData
|
||||
expectError bool
|
||||
name string
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid chunk encoding/decoding",
|
||||
@@ -1619,10 +1619,10 @@ func TestValidateChunkIntegrity(t *testing.T) {
|
||||
cs := NewChunkSerializer(NewNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
chunk ChunkData
|
||||
expectError bool
|
||||
name string
|
||||
description string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid chunk integrity",
|
||||
|
||||
@@ -254,10 +254,10 @@ func (cs *ChunkSerializer) calculateChecksum(content string) string {
|
||||
|
||||
// ChunkData represents a single chunk of token data
|
||||
type ChunkData struct {
|
||||
Index int // Position of this chunk in the sequence
|
||||
Total int // Total number of chunks for this token
|
||||
Content string // The actual chunk content
|
||||
Checksum string // Simple checksum for integrity verification
|
||||
Content string
|
||||
Checksum string
|
||||
Index int
|
||||
Total int
|
||||
}
|
||||
|
||||
// EstimateChunkCount estimates how many chunks a token will need
|
||||
|
||||
+10
-11
@@ -84,21 +84,19 @@ type TokenRetrievalResult struct {
|
||||
// and error handling to ensure data integrity and prevent security vulnerabilities
|
||||
// throughout the process.
|
||||
type ChunkManager struct {
|
||||
logger *Logger
|
||||
mutex *sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup // WaitGroup to track background goroutine completion
|
||||
// sessionMap provides bounded session storage to prevent memory leaks
|
||||
lastCleanup time.Time
|
||||
ctx context.Context
|
||||
mutex *sync.RWMutex
|
||||
cancel context.CancelFunc
|
||||
sessionMap map[string]*SessionEntry
|
||||
logger *Logger
|
||||
wg sync.WaitGroup
|
||||
maxSessions int
|
||||
sessionTTL time.Duration
|
||||
lastCleanup time.Time
|
||||
cleanupRunning int32 // atomic flag to prevent concurrent cleanups
|
||||
// Memory usage tracking
|
||||
bytesAllocated int64
|
||||
peakSessions int64
|
||||
cleanupCount int64
|
||||
cleanupRunning int32
|
||||
}
|
||||
|
||||
// SessionEntry represents a session with expiration tracking
|
||||
@@ -393,7 +391,8 @@ func (cm *ChunkManager) processChunkedToken(chunks map[int]*sessions.Session, co
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
}
|
||||
|
||||
if len(chunk) > maxBrowserCookieSize {
|
||||
// Secondary check: ensure chunk won't exceed browser limit after encoding (~2x overhead)
|
||||
if len(chunk) > maxCookieSize*2 {
|
||||
err := fmt.Errorf("%s token chunk %d exceeds browser limit (%d bytes)",
|
||||
config.Type, i, len(chunk))
|
||||
return TokenRetrievalResult{Token: "", Error: err}
|
||||
@@ -1199,8 +1198,8 @@ func (cm *ChunkManager) findOldestSessions(k int) []string {
|
||||
|
||||
// Collect all timestamps with keys
|
||||
type sessionAge struct {
|
||||
key string
|
||||
lastUsed time.Time
|
||||
key string
|
||||
}
|
||||
|
||||
sessions := make([]sessionAge, 0, len(cm.sessionMap))
|
||||
|
||||
+15
-15
@@ -24,31 +24,31 @@ import (
|
||||
|
||||
// SessionTestCase represents a comprehensive session test scenario
|
||||
type SessionTestCase struct {
|
||||
name string
|
||||
scenario string // "creation", "validation", "expiration", "persistence", "cleanup", "chunking", "security"
|
||||
sessionType string // "user", "admin", "api", "guest", "csrf"
|
||||
setup func(*SessionTestFramework)
|
||||
execute func(*SessionTestFramework) error
|
||||
validate func(*testing.T, error, *SessionTestFramework)
|
||||
cleanup func(*SessionTestFramework)
|
||||
concurrent bool
|
||||
name string
|
||||
scenario string
|
||||
sessionType string
|
||||
skipReason string
|
||||
iterations int
|
||||
timeout time.Duration
|
||||
skipReason string
|
||||
concurrent bool
|
||||
}
|
||||
|
||||
// SessionTestFramework provides shared test infrastructure for session tests
|
||||
type SessionTestFramework struct {
|
||||
t *testing.T
|
||||
mockProvider *httptest.Server
|
||||
testTokens map[string]string
|
||||
metrics *SessionTestMetrics
|
||||
config *SessionTestConfig
|
||||
requests []*http.Request
|
||||
responses []*httptest.ResponseRecorder
|
||||
testTokens map[string]string
|
||||
sessionIDs []string
|
||||
mu sync.RWMutex
|
||||
metrics *SessionTestMetrics
|
||||
cleanupFuncs []func()
|
||||
config *SessionTestConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SessionTestMetrics tracks test performance metrics
|
||||
@@ -65,12 +65,12 @@ type SessionTestMetrics struct {
|
||||
|
||||
// SessionTestConfig holds test configuration
|
||||
type SessionTestConfig struct {
|
||||
CookieDomain string
|
||||
EncryptionKey string
|
||||
MaxChunkSize int
|
||||
MaxSessions int
|
||||
EnableHTTPS bool
|
||||
CookieDomain string
|
||||
SessionTimeout time.Duration
|
||||
EncryptionKey string
|
||||
EnableHTTPS bool
|
||||
EnableCompression bool
|
||||
}
|
||||
|
||||
@@ -2849,9 +2849,9 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
expectedBehavior string
|
||||
sessionAge time.Duration
|
||||
tokenExpiry time.Duration
|
||||
expectedBehavior string
|
||||
sessionShouldExpire bool
|
||||
tokenShouldRefresh bool
|
||||
}{
|
||||
@@ -2975,10 +2975,10 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
scenarios := []struct {
|
||||
name string
|
||||
tokenExpiry time.Duration
|
||||
shouldCleanup bool
|
||||
shouldPreserve []string
|
||||
shouldRemove []string
|
||||
tokenExpiry time.Duration
|
||||
shouldCleanup bool
|
||||
}{
|
||||
{
|
||||
name: "Recently expired tokens - preserve session",
|
||||
|
||||
+103
-334
@@ -27,359 +27,128 @@ type TemplatedHeader struct {
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
CookiePrefix string `json:"cookiePrefix"` // Prefix for session cookie names (default: "_oidc_raczylo_")
|
||||
SessionMaxAge int `json:"sessionMaxAge"` // Maximum session age in seconds (default: 86400 = 24 hours)
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
// Audience specifies the expected JWT audience claim value.
|
||||
// If not set, defaults to ClientID for backward compatibility.
|
||||
// For Auth0 API access tokens with custom audiences, set this to your API identifier.
|
||||
// For Azure AD with Application ID URI, set to "api://your-app-id".
|
||||
// Security: This value is validated against the JWT aud claim to prevent token confusion attacks.
|
||||
Audience string `json:"audience,omitempty"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
// StrictAudienceValidation enforces strict audience validation for access tokens.
|
||||
// When enabled, sessions are rejected if access token validation fails (prevents fallback to ID token).
|
||||
// This addresses Auth0 Scenario 2 security concerns where access tokens without proper
|
||||
// audience claims could be accepted based on ID token validation.
|
||||
// Default: false (backward compatible - allows ID token fallback)
|
||||
// Recommended: true for production environments requiring strict OAuth 2.0 compliance
|
||||
StrictAudienceValidation bool `json:"strictAudienceValidation,omitempty"`
|
||||
// AllowOpaqueTokens enables acceptance of non-JWT (opaque) access tokens.
|
||||
// When enabled, opaque tokens are validated via OAuth 2.0 Token Introspection (RFC 7662).
|
||||
// This supports Auth0 Scenario 3 and other providers that issue opaque access tokens.
|
||||
// Default: false (only JWT access tokens accepted)
|
||||
// Note: Requires introspection endpoint to be available from provider metadata
|
||||
AllowOpaqueTokens bool `json:"allowOpaqueTokens,omitempty"`
|
||||
// RequireTokenIntrospection forces token introspection for all opaque access tokens.
|
||||
// When enabled, opaque tokens are rejected if introspection endpoint is unavailable.
|
||||
// When disabled, opaque tokens fall back to ID token validation.
|
||||
// Default: false (allows fallback to ID token)
|
||||
// Recommended: true when AllowOpaqueTokens is enabled for maximum security
|
||||
RequireTokenIntrospection bool `json:"requireTokenIntrospection,omitempty"`
|
||||
// DisableReplayDetection disables JTI-based replay attack detection.
|
||||
// Enable this when running multiple Traefik replicas to prevent false positives.
|
||||
// Each replica maintains its own in-memory JTI cache, so the same valid token
|
||||
// hitting different replicas will trigger replay detection on subsequent requests.
|
||||
//
|
||||
// Security Note: When enabled, the plugin still validates token signatures,
|
||||
// expiration, and other claims. Only the JTI replay check is disabled.
|
||||
// Consider using a shared cache backend (Redis/Memcached) if replay detection
|
||||
// is required in multi-replica scenarios.
|
||||
//
|
||||
// Default: false (replay detection enabled)
|
||||
// Recommended: true for multi-replica deployments
|
||||
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
|
||||
// Redis configures the Redis cache backend for distributed caching.
|
||||
// When enabled, provides cache sharing across multiple Traefik replicas.
|
||||
// Default: nil (disabled - uses in-memory caching)
|
||||
Redis *RedisConfig `json:"redis,omitempty"`
|
||||
|
||||
// RoleClaimName specifies the JWT claim name to extract user roles from.
|
||||
// This allows compatibility with different OIDC providers that use different claim names.
|
||||
//
|
||||
// Examples:
|
||||
// - Default (backward compatible): "roles"
|
||||
// - Auth0 namespaced: "https://myapp.com/roles"
|
||||
// - Keycloak realm roles: "realm_access.roles"
|
||||
// - Custom claim: "user_roles"
|
||||
//
|
||||
// If not specified, defaults to "roles" for backward compatibility.
|
||||
// Supports both simple names and namespaced URIs per OIDC specification.
|
||||
//
|
||||
// Default: "roles"
|
||||
RoleClaimName string `json:"roleClaimName,omitempty"`
|
||||
|
||||
// GroupClaimName specifies the JWT claim name to extract user groups from.
|
||||
// This allows compatibility with different OIDC providers that use different claim names.
|
||||
//
|
||||
// Examples:
|
||||
// - Default (backward compatible): "groups"
|
||||
// - Auth0 namespaced: "https://myapp.com/groups"
|
||||
// - Azure AD groups: "groups"
|
||||
// - Custom claim: "user_groups"
|
||||
//
|
||||
// If not specified, defaults to "groups" for backward compatibility.
|
||||
// Supports both simple names and namespaced URIs per OIDC specification.
|
||||
//
|
||||
// Default: "groups"
|
||||
GroupClaimName string `json:"groupClaimName,omitempty"`
|
||||
|
||||
// UserIdentifierClaim specifies the JWT claim to use as the user identifier.
|
||||
// This allows authentication for users without email addresses (e.g., Azure AD service accounts).
|
||||
//
|
||||
// Examples:
|
||||
// - Default (backward compatible): "email"
|
||||
// - Azure AD without email: "sub", "oid", "upn", or "preferred_username"
|
||||
// - Generic OIDC: "sub" (always present per OIDC spec)
|
||||
//
|
||||
// When set to a non-email claim:
|
||||
// - AllowedUsers will match against this claim value instead of email
|
||||
// - AllowedUserDomains validation is skipped (domains only apply to email)
|
||||
// - The session will store this identifier as the user's identity
|
||||
//
|
||||
// Default: "email"
|
||||
UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"`
|
||||
|
||||
// DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591)
|
||||
// When enabled, the middleware will automatically register as a client with
|
||||
// the OIDC provider if ClientID/ClientSecret are not provided.
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
|
||||
// AllowPrivateIPAddresses disables the security check that blocks private/internal IP addresses.
|
||||
// By default, the plugin rejects URLs containing private IP ranges (10.x.x.x, 172.16-31.x.x, 192.168.x.x)
|
||||
// to prevent SSRF attacks and ensure OIDC providers are publicly accessible.
|
||||
//
|
||||
// Enable this option ONLY when:
|
||||
// - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
// - You have no DNS resolution available for internal services
|
||||
// - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
//
|
||||
// Security Warning: Enabling this option reduces SSRF protection. Only use in trusted
|
||||
// network environments where the OIDC provider is known and controlled.
|
||||
//
|
||||
// Default: false (private IPs are blocked for security)
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
|
||||
// MinimalHeaders reduces the number of headers forwarded to downstream services.
|
||||
// This helps prevent "431 Request Header Fields Too Large" errors when downstream
|
||||
// services have limited header buffer sizes.
|
||||
//
|
||||
// When enabled (true):
|
||||
// - Only forwards: X-Forwarded-User
|
||||
// - Skips: X-Auth-Request-Token (full ID token), X-Auth-Request-Redirect
|
||||
// - Groups/roles headers (X-User-Groups, X-User-Roles) are still forwarded if configured
|
||||
// - Custom templated headers are still processed
|
||||
//
|
||||
// When disabled (false, default):
|
||||
// - Forwards all headers: X-Forwarded-User, X-Auth-Request-User, X-Auth-Request-Redirect,
|
||||
// X-Auth-Request-Token (full ID token)
|
||||
//
|
||||
// Use this option when:
|
||||
// - Downstream services return "431 Request Header Fields Too Large" errors
|
||||
// - You don't need the full ID token forwarded to backend services
|
||||
// - You want to reduce request overhead
|
||||
//
|
||||
// Default: false (all headers forwarded for backward compatibility)
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
Redis *RedisConfig `json:"redis,omitempty"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
Audience string `json:"audience,omitempty"`
|
||||
CookiePrefix string `json:"cookiePrefix"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"`
|
||||
GroupClaimName string `json:"groupClaimName,omitempty"`
|
||||
RoleClaimName string `json:"roleClaimName,omitempty"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
Scopes []string `json:"scopes"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
SessionMaxAge int `json:"sessionMaxAge"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
|
||||
RequireTokenIntrospection bool `json:"requireTokenIntrospection,omitempty"`
|
||||
AllowOpaqueTokens bool `json:"allowOpaqueTokens,omitempty"`
|
||||
StrictAudienceValidation bool `json:"strictAudienceValidation,omitempty"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
// All fields support both JSON and YAML configuration for compatibility with Traefik's
|
||||
// dynamic configuration (labels, YAML files, etc.)
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis caching should be used (default: false)
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
|
||||
// Address is the Redis server address (e.g., "localhost:6379", "redis:6379")
|
||||
Address string `json:"address" yaml:"address"`
|
||||
|
||||
// Password for Redis authentication (optional, leave empty for no auth)
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the Redis database number to use (default: 0)
|
||||
DB int `json:"db" yaml:"db"`
|
||||
|
||||
// KeyPrefix is the prefix for all Redis keys (default: "traefikoidc:")
|
||||
KeyPrefix string `json:"keyPrefix" yaml:"keyPrefix"`
|
||||
|
||||
// PoolSize is the maximum number of socket connections (default: 10)
|
||||
PoolSize int `json:"poolSize" yaml:"poolSize"`
|
||||
|
||||
// ConnectTimeout is the timeout for establishing connections in seconds (default: 5)
|
||||
ConnectTimeout int `json:"connectTimeout" yaml:"connectTimeout"`
|
||||
|
||||
// ReadTimeout is the timeout for read operations in seconds (default: 3)
|
||||
ReadTimeout int `json:"readTimeout" yaml:"readTimeout"`
|
||||
|
||||
// WriteTimeout is the timeout for write operations in seconds (default: 3)
|
||||
WriteTimeout int `json:"writeTimeout" yaml:"writeTimeout"`
|
||||
|
||||
// EnableTLS indicates if TLS should be used for Redis connections (default: false)
|
||||
EnableTLS bool `json:"enableTLS" yaml:"enableTLS"`
|
||||
|
||||
// TLSSkipVerify skips TLS certificate verification (not recommended for production)
|
||||
TLSSkipVerify bool `json:"tlsSkipVerify" yaml:"tlsSkipVerify"`
|
||||
|
||||
// CacheMode determines the caching strategy: "redis" (Redis only), "hybrid" (Memory+Redis), "memory" (Memory only)
|
||||
// Default: "redis" when enabled
|
||||
CacheMode string `json:"cacheMode" yaml:"cacheMode"`
|
||||
|
||||
// HybridL1Size is the maximum number of items in L1 cache for hybrid mode (default: 500)
|
||||
HybridL1Size int `json:"hybridL1Size" yaml:"hybridL1Size"`
|
||||
|
||||
// HybridL1MemoryMB is the maximum memory in MB for L1 cache in hybrid mode (default: 10)
|
||||
HybridL1MemoryMB int64 `json:"hybridL1MemoryMB" yaml:"hybridL1MemoryMB"`
|
||||
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis failures (default: true)
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker" yaml:"enableCircuitBreaker"`
|
||||
|
||||
// CircuitBreakerThreshold is the number of failures before opening circuit (default: 5)
|
||||
CircuitBreakerThreshold int `json:"circuitBreakerThreshold" yaml:"circuitBreakerThreshold"`
|
||||
|
||||
// CircuitBreakerTimeout is the timeout in seconds before attempting to close circuit (default: 60)
|
||||
CircuitBreakerTimeout int `json:"circuitBreakerTimeout" yaml:"circuitBreakerTimeout"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks for Redis (default: true)
|
||||
EnableHealthCheck bool `json:"enableHealthCheck" yaml:"enableHealthCheck"`
|
||||
|
||||
// HealthCheckInterval is the interval in seconds between health checks (default: 30)
|
||||
HealthCheckInterval int `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
KeyPrefix string `json:"keyPrefix" yaml:"keyPrefix"`
|
||||
Address string `json:"address" yaml:"address"`
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
CacheMode string `json:"cacheMode" yaml:"cacheMode"`
|
||||
WriteTimeout int `json:"writeTimeout" yaml:"writeTimeout"`
|
||||
CircuitBreakerThreshold int `json:"circuitBreakerThreshold" yaml:"circuitBreakerThreshold"`
|
||||
ConnectTimeout int `json:"connectTimeout" yaml:"connectTimeout"`
|
||||
ReadTimeout int `json:"readTimeout" yaml:"readTimeout"`
|
||||
PoolSize int `json:"poolSize" yaml:"poolSize"`
|
||||
HealthCheckInterval int `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
CircuitBreakerTimeout int `json:"circuitBreakerTimeout" yaml:"circuitBreakerTimeout"`
|
||||
DB int `json:"db" yaml:"db"`
|
||||
HybridL1Size int `json:"hybridL1Size" yaml:"hybridL1Size"`
|
||||
HybridL1MemoryMB int64 `json:"hybridL1MemoryMB" yaml:"hybridL1MemoryMB"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker" yaml:"enableCircuitBreaker"`
|
||||
TLSSkipVerify bool `json:"tlsSkipVerify" yaml:"tlsSkipVerify"`
|
||||
EnableHealthCheck bool `json:"enableHealthCheck" yaml:"enableHealthCheck"`
|
||||
EnableTLS bool `json:"enableTLS" yaml:"enableTLS"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrationConfig struct {
|
||||
// Enabled enables automatic client registration with the OIDC provider
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// InitialAccessToken is an optional bearer token for protected registration endpoints
|
||||
// Some providers require this token to authorize new client registrations
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
|
||||
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
|
||||
// If empty, uses the registration_endpoint from .well-known/openid-configuration
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
|
||||
// ClientMetadata contains the client metadata to register
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
|
||||
// PersistCredentials determines whether to save registered credentials to a file
|
||||
// This allows reusing the same client_id/client_secret across restarts
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
|
||||
// CredentialsFile is the path to store/load registered client credentials
|
||||
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
Enabled bool `json:"enabled"`
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
type ClientRegistrationMetadata struct {
|
||||
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
|
||||
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
|
||||
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
|
||||
// ApplicationType is either "web" (default) or "native"
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
|
||||
// Contacts is an array of email addresses for responsible parties
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
|
||||
// ClientName is a human-readable name for the client
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
|
||||
// LogoURI is a URL pointing to a logo for the client
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
|
||||
// ClientURI is a URL of the home page of the client
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
|
||||
// PolicyURI is a URL pointing to the client's privacy policy
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
|
||||
// TOSURI is a URL pointing to the client's terms of service
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
|
||||
// JWKSURI is a URL for the client's JSON Web Key Set
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
|
||||
// SubjectType is "pairwise" or "public" (provider-specific)
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
|
||||
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
|
||||
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
|
||||
// DefaultMaxAge is the default maximum authentication age in seconds
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
|
||||
// RequireAuthTime specifies whether auth_time claim is required in ID token
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
|
||||
// DefaultACRValues specifies default ACR values
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
|
||||
// Scope is a space-separated list of scopes (alternative to config.Scopes)
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
type SecurityHeadersConfig struct {
|
||||
// Enable security headers (default: true)
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Security profile: "default", "strict", "development", "api", or "custom"
|
||||
Profile string `json:"profile"`
|
||||
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
|
||||
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
|
||||
// Content type options (default: "nosniff")
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
|
||||
// XSS protection (default: "1; mode=block")
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
CORSMaxAge int `json:"corsMaxAge"` // seconds
|
||||
|
||||
// Custom headers (in addition to standard security headers)
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
Profile string `json:"profile"`
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"`
|
||||
CORSMaxAge int `json:"corsMaxAge"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
+1
-1
@@ -17,8 +17,8 @@ type ShardedCache struct {
|
||||
|
||||
// cacheShard represents a single shard with its own mutex and data map.
|
||||
type cacheShard struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*shardedCacheItem
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// shardedCacheItem represents an item in the sharded cache with expiration.
|
||||
|
||||
+18
-33
@@ -18,33 +18,20 @@ var (
|
||||
// ResourceManager manages shared resources across all middleware instances
|
||||
// to prevent duplication and goroutine leaks when Traefik recreates middleware
|
||||
type ResourceManager struct {
|
||||
// HTTP clients shared across instances
|
||||
httpClients map[string]*http.Client
|
||||
clientsMu sync.RWMutex
|
||||
|
||||
// Caches shared across instances
|
||||
caches map[string]interface{}
|
||||
cachesMu sync.RWMutex
|
||||
|
||||
// Background tasks registry
|
||||
tasks map[string]*BackgroundTask
|
||||
tasksMu sync.RWMutex
|
||||
|
||||
// Goroutine pools for controlled concurrency
|
||||
pools map[string]*GoroutinePool
|
||||
poolsMu sync.RWMutex
|
||||
|
||||
// Reference counting for cleanup
|
||||
references map[string]*int32
|
||||
referencesMu sync.RWMutex
|
||||
|
||||
// Logger
|
||||
logger *Logger
|
||||
|
||||
// Shutdown coordination
|
||||
shutdownOnce sync.Once
|
||||
caches map[string]interface{}
|
||||
httpClients map[string]*http.Client
|
||||
tasks map[string]*BackgroundTask
|
||||
shutdownChan chan struct{}
|
||||
pools map[string]*GoroutinePool
|
||||
logger *Logger
|
||||
wg sync.WaitGroup
|
||||
cachesMu sync.RWMutex
|
||||
referencesMu sync.RWMutex
|
||||
poolsMu sync.RWMutex
|
||||
tasksMu sync.RWMutex
|
||||
clientsMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// GetResourceManager returns the global singleton ResourceManager instance
|
||||
@@ -338,17 +325,15 @@ func (rm *ResourceManager) Shutdown(ctx context.Context) error {
|
||||
|
||||
// GoroutinePool provides a pool of workers for controlled concurrency
|
||||
type GoroutinePool struct {
|
||||
maxWorkers int
|
||||
taskQueue chan func()
|
||||
workerWG sync.WaitGroup
|
||||
shutdownOnce sync.Once
|
||||
shutdownChan chan struct{}
|
||||
logger *Logger
|
||||
started int32
|
||||
|
||||
// Condition variable for efficient Wait() without busy-polling
|
||||
taskCond *sync.Cond
|
||||
pendingTasks int64 // atomic counter for pending tasks
|
||||
workerWG sync.WaitGroup
|
||||
maxWorkers int
|
||||
pendingTasks int64
|
||||
shutdownOnce sync.Once
|
||||
started int32
|
||||
}
|
||||
|
||||
// NewGoroutinePool creates a new goroutine pool with the specified max workers
|
||||
@@ -517,10 +502,10 @@ func (p *GoroutinePool) Shutdown(ctx context.Context) error {
|
||||
// GenericCache provides a simple cache implementation for testing
|
||||
type GenericCache struct {
|
||||
data map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewGenericCache creates a new generic cache
|
||||
|
||||
@@ -2,6 +2,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
@@ -97,6 +99,89 @@ func TestSingletonResourceManager(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MultiRealmMetadataRefreshTaskNaming", func(t *testing.T) {
|
||||
// This test verifies that different provider URLs generate different task names
|
||||
// which is critical for multi-realm Keycloak support (PR #88)
|
||||
|
||||
// Reset singletons for clean test state
|
||||
resetResourceManagerForTesting()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Simulate different Keycloak realms
|
||||
providerURL1 := "https://keycloak.example.com/realms/realm1"
|
||||
providerURL2 := "https://keycloak.example.com/realms/realm2"
|
||||
|
||||
// Generate task names using the same logic as startMetadataRefresh
|
||||
hash1 := sha256.Sum256([]byte(providerURL1))
|
||||
taskName1 := "singleton-metadata-refresh-" + hex.EncodeToString(hash1[:])[0:6]
|
||||
|
||||
hash2 := sha256.Sum256([]byte(providerURL2))
|
||||
taskName2 := "singleton-metadata-refresh-" + hex.EncodeToString(hash2[:])[0:6]
|
||||
|
||||
// Verify task names are different
|
||||
if taskName1 == taskName2 {
|
||||
t.Errorf("Task names should be different for different provider URLs: %s vs %s", taskName1, taskName2)
|
||||
}
|
||||
|
||||
// Register both tasks
|
||||
task1Called := int32(0)
|
||||
task2Called := int32(0)
|
||||
|
||||
err := rm.RegisterBackgroundTask(taskName1, 100*time.Millisecond, func() {
|
||||
atomic.AddInt32(&task1Called, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to register task 1: %v", err)
|
||||
}
|
||||
|
||||
err = rm.RegisterBackgroundTask(taskName2, 100*time.Millisecond, func() {
|
||||
atomic.AddInt32(&task2Called, 1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to register task 2: %v", err)
|
||||
}
|
||||
|
||||
// Start both tasks
|
||||
_ = rm.StartBackgroundTask(taskName1)
|
||||
_ = rm.StartBackgroundTask(taskName2)
|
||||
|
||||
// Wait for tasks to execute
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Verify both tasks are running independently
|
||||
if !rm.IsTaskRunning(taskName1) {
|
||||
t.Error("Task 1 should be running")
|
||||
}
|
||||
if !rm.IsTaskRunning(taskName2) {
|
||||
t.Error("Task 2 should be running")
|
||||
}
|
||||
|
||||
// Verify both tasks were called (at least once)
|
||||
if atomic.LoadInt32(&task1Called) == 0 {
|
||||
t.Error("Task 1 should have been called at least once")
|
||||
}
|
||||
if atomic.LoadInt32(&task2Called) == 0 {
|
||||
t.Error("Task 2 should have been called at least once")
|
||||
}
|
||||
|
||||
// Stop both tasks
|
||||
_ = rm.StopBackgroundTask(taskName1)
|
||||
_ = rm.StopBackgroundTask(taskName2)
|
||||
|
||||
// Verify tasks are stopped
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if rm.IsTaskRunning(taskName1) {
|
||||
t.Error("Task 1 should be stopped")
|
||||
}
|
||||
if rm.IsTaskRunning(taskName2) {
|
||||
t.Error("Task 2 should be stopped")
|
||||
}
|
||||
|
||||
t.Logf("Successfully verified multi-realm task isolation: task1=%s, task2=%s", taskName1, taskName2)
|
||||
})
|
||||
|
||||
t.Run("ReferenceCountingCleanup", func(t *testing.T) {
|
||||
rm := GetResourceManager()
|
||||
|
||||
|
||||
+13
-20
@@ -9,26 +9,19 @@ import (
|
||||
|
||||
// TestConfig manages test execution configuration and performance settings
|
||||
type TestConfig struct {
|
||||
// Test execution modes
|
||||
ExtendedTests bool // Run extended/stress tests
|
||||
LongTests bool // Run long-running performance tests
|
||||
QuickMode bool // Quick smoke tests only
|
||||
|
||||
// Performance settings
|
||||
MaxConcurrency int // Maximum concurrent operations
|
||||
MaxIterations int // Maximum test iterations
|
||||
DefaultTimeout time.Duration // Default test timeout
|
||||
MemoryThreshold float64 // Memory growth threshold in MB
|
||||
GoroutineGrowth int // Acceptable goroutine growth
|
||||
|
||||
// Cache settings for tests
|
||||
CacheSize int // Default cache size for tests
|
||||
CleanupInterval time.Duration // Cleanup interval for tests
|
||||
|
||||
// Environment-specific overrides
|
||||
MemoryStressTest bool // Enable memory stress tests
|
||||
ConcurrencyTest bool // Enable high concurrency tests
|
||||
LeakDetection bool // Enable memory leak detection
|
||||
MemoryThreshold float64
|
||||
MaxConcurrency int
|
||||
MaxIterations int
|
||||
DefaultTimeout time.Duration
|
||||
GoroutineGrowth int
|
||||
CacheSize int
|
||||
CleanupInterval time.Duration
|
||||
LongTests bool
|
||||
QuickMode bool
|
||||
ExtendedTests bool
|
||||
MemoryStressTest bool
|
||||
ConcurrencyTest bool
|
||||
LeakDetection bool
|
||||
}
|
||||
|
||||
// NewTestConfig creates a test configuration based on flags and environment
|
||||
|
||||
@@ -21,11 +21,11 @@ type TestFramework struct {
|
||||
server *httptest.Server
|
||||
oidc *TraefikOidc
|
||||
config *Config
|
||||
cleanup []func()
|
||||
mocks *TestMocks
|
||||
fixtures *TestFixtures
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey *rsa.PublicKey
|
||||
cleanup []func()
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
@@ -457,12 +457,12 @@ func GetTestFramework() *TestFramework {
|
||||
|
||||
// TestScenario represents a test scenario
|
||||
type TestScenario struct {
|
||||
Name string
|
||||
Setup func(*TestFramework)
|
||||
Request func(*TestFramework) *http.Request
|
||||
ExpectedStatus int
|
||||
ExpectedBody string
|
||||
Validate func(*TestFramework, *httptest.ResponseRecorder)
|
||||
Name string
|
||||
ExpectedBody string
|
||||
ExpectedStatus int
|
||||
}
|
||||
|
||||
// RunScenarios executes a set of test scenarios
|
||||
|
||||
+15
-15
@@ -17,10 +17,10 @@ import (
|
||||
|
||||
// GlobalTestCleanup tracks and cleans up test resources
|
||||
type GlobalTestCleanup struct {
|
||||
mu sync.Mutex
|
||||
servers []*httptest.Server
|
||||
tasks []*BackgroundTask
|
||||
caches []interface{ Close() }
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var globalCleanup = &GlobalTestCleanup{}
|
||||
@@ -187,13 +187,13 @@ func GetTestDuration(normal time.Duration) time.Duration {
|
||||
|
||||
// UnifiedMockSession provides a comprehensive mock for the Session interface
|
||||
type UnifiedMockSession struct {
|
||||
mu sync.RWMutex
|
||||
data map[string]interface{}
|
||||
callCounts map[string]int64
|
||||
errors map[string]error
|
||||
delays map[string]time.Duration
|
||||
destroyed bool
|
||||
destroyCount int64
|
||||
mu sync.RWMutex
|
||||
destroyed bool
|
||||
}
|
||||
|
||||
// NewUnifiedMockSession creates a new mock session with default behavior
|
||||
@@ -326,13 +326,13 @@ func (m *UnifiedMockSession) GetDestroyCount() int64 {
|
||||
|
||||
// UnifiedMockTokenVerifier provides a comprehensive mock for token verification
|
||||
type UnifiedMockTokenVerifier struct {
|
||||
mu sync.RWMutex
|
||||
validTokens map[string]bool
|
||||
tokenMetadata map[string]map[string]interface{}
|
||||
callCounts map[string]int64
|
||||
errors map[string]error
|
||||
delays map[string]time.Duration
|
||||
verificationFunc func(string) error
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewUnifiedMockTokenVerifier creates a new mock token verifier
|
||||
@@ -414,19 +414,19 @@ func (m *UnifiedMockTokenVerifier) VerifyToken(token string) error {
|
||||
|
||||
// UnifiedMockTokenCache provides a comprehensive mock for token caching
|
||||
type UnifiedMockTokenCache struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]TestCacheEntry
|
||||
callCounts map[string]int64
|
||||
errors map[string]error
|
||||
delays map[string]time.Duration
|
||||
hitRate float64
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// TestCacheEntry represents a cached token entry for testing
|
||||
type TestCacheEntry struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
Metadata map[string]interface{}
|
||||
Token string
|
||||
}
|
||||
|
||||
// NewUnifiedMockTokenCache creates a new mock token cache
|
||||
@@ -539,39 +539,39 @@ func (m *UnifiedMockTokenCache) Clear() {
|
||||
|
||||
// TableTestCase represents a standardized test case structure
|
||||
type TableTestCase struct {
|
||||
Name string
|
||||
Description string
|
||||
Input interface{}
|
||||
Expected interface{}
|
||||
ExpectedError error
|
||||
Setup func(*testing.T) error
|
||||
Teardown func(*testing.T) error
|
||||
Timeout time.Duration
|
||||
Name string
|
||||
Description string
|
||||
SkipReason string
|
||||
Tags []string
|
||||
Timeout time.Duration
|
||||
Parallel bool
|
||||
}
|
||||
|
||||
// MemoryLeakTestCase represents a test case specifically for memory leak detection
|
||||
type MemoryLeakTestCase struct {
|
||||
Operation func() error
|
||||
Setup func() error
|
||||
Teardown func() error
|
||||
Name string
|
||||
Description string
|
||||
Operation func() error
|
||||
Iterations int
|
||||
MaxGoroutineGrowth int
|
||||
MaxMemoryGrowthMB float64
|
||||
Setup func() error
|
||||
Teardown func() error
|
||||
GCBetweenRuns bool
|
||||
Timeout time.Duration
|
||||
GCBetweenRuns bool
|
||||
}
|
||||
|
||||
// TestSuiteRunner provides utilities for running table-driven tests
|
||||
type TestSuiteRunner struct {
|
||||
parallelTests bool
|
||||
timeout time.Duration
|
||||
beforeEach func(*testing.T)
|
||||
afterEach func(*testing.T)
|
||||
timeout time.Duration
|
||||
parallelTests bool
|
||||
}
|
||||
|
||||
// NewTestSuiteRunner creates a new test suite runner
|
||||
|
||||
+1
-1
@@ -20,9 +20,9 @@ func generateRandomString(length int) string {
|
||||
// Test createCaseInsensitiveStringMap function
|
||||
func TestCreateCaseInsensitiveStringMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected map[string]struct{}
|
||||
name string
|
||||
items []string
|
||||
expected map[string]struct{}
|
||||
}{
|
||||
{
|
||||
name: "Mixed case items",
|
||||
|
||||
+12
-12
@@ -16,18 +16,18 @@ import (
|
||||
// IntrospectionResponse represents the response from an OAuth 2.0 token introspection endpoint.
|
||||
// Per RFC 7662, this contains information about the token's validity and properties.
|
||||
type IntrospectionResponse struct {
|
||||
Active bool `json:"active"` // REQUIRED - whether the token is currently active
|
||||
Scope string `json:"scope,omitempty"` // Space-separated list of scopes
|
||||
ClientID string `json:"client_id,omitempty"` // Client identifier for the token
|
||||
Username string `json:"username,omitempty"` // Human-readable identifier for the resource owner
|
||||
TokenType string `json:"token_type,omitempty"` // Type of token (e.g., "Bearer")
|
||||
Exp int64 `json:"exp,omitempty"` // Expiration time (seconds since epoch)
|
||||
Iat int64 `json:"iat,omitempty"` // Issued at time (seconds since epoch)
|
||||
Nbf int64 `json:"nbf,omitempty"` // Not before time (seconds since epoch)
|
||||
Sub string `json:"sub,omitempty"` // Subject of the token
|
||||
Aud string `json:"aud,omitempty"` // Intended audience
|
||||
Iss string `json:"iss,omitempty"` // Issuer
|
||||
Jti string `json:"jti,omitempty"` // JWT ID
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Aud string `json:"aud,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
Jti string `json:"jti,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
Nbf int64 `json:"nbf,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
// introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user