diff --git a/cache_manager.go b/cache_manager.go index e61ec31..e3997c8 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -61,7 +61,7 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface { cm.mu.RLock() defer cm.mu.RUnlock() - return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache()} + return &CacheInterfaceWrapper{cache: cm.manager.GetBlacklistCache(), managed: true} } // GetSharedTokenCache returns the shared token cache @@ -93,7 +93,7 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface { func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface { cm.mu.RLock() defer cm.mu.RUnlock() - return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()} + return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache(), managed: true} } // GetSharedTokenTypeCache returns the shared token type cache @@ -101,7 +101,7 @@ func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface { func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface { cm.mu.RLock() defer cm.mu.RUnlock() - return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()} + return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true} } // Close gracefully shuts down all cache components @@ -121,7 +121,8 @@ func CleanupGlobalCacheManager() error { // CacheInterfaceWrapper wraps UniversalCache to implement CacheInterface type CacheInterfaceWrapper struct { - cache *UniversalCache + cache *UniversalCache + managed bool // If true, cache is managed globally and Close() is a no-op } // Set stores a value @@ -149,9 +150,15 @@ func (c *CacheInterfaceWrapper) Cleanup() { c.cache.Cleanup() } -// Close shuts down the cache +// Close shuts down the cache if it's not managed globally. +// For managed caches (from UniversalCacheManager), this is a no-op to prevent log flooding +// when multiple plugin instances are closed during Traefik configuration reloads. func (c *CacheInterfaceWrapper) Close() { - // Close the underlying cache to stop goroutines + if c.managed { + // Cache is managed globally by UniversalCacheManager, so we don't close it here. + return + } + // Standalone cache - close it properly to stop cleanup goroutines if c.cache != nil { _ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown } diff --git a/cache_test.go b/cache_test.go index 80e806e..dbaab1f 100644 --- a/cache_test.go +++ b/cache_test.go @@ -219,6 +219,159 @@ func TestCacheInterfaceWrapper_Close(t *testing.T) { nilWrapper.Close() } +// TestCacheInterfaceWrapper_ManagedClose_Regression tests that managed cache wrappers +// don't close the underlying cache when Close() is called. This is a regression test +// for issue #105 where multiple plugin instances closing shared caches caused log flooding. +func TestCacheInterfaceWrapper_ManagedClose_Regression(t *testing.T) { + cm := getTestCacheManager(t) + + // Get a managed cache wrapper + cache := cm.GetSharedTokenBlacklist() + wrapper, ok := cache.(*CacheInterfaceWrapper) + if !ok { + t.Fatal("Expected CacheInterfaceWrapper") + } + + // Verify it's marked as managed + if !wrapper.managed { + t.Error("Expected shared cache wrapper to be marked as managed") + } + + // Set some data before Close + cache.Set("test-key", "test-value", time.Hour) + + // Close the wrapper (should be a no-op for managed caches) + wrapper.Close() + + // Verify the cache is still operational after Close + value, found := cache.Get("test-key") + if !found { + t.Error("Expected cache to still work after Close() on managed wrapper") + } + if value != "test-value" { + t.Errorf("Expected 'test-value', got %v", value) + } + + // Can still set new values + cache.Set("new-key", "new-value", time.Hour) + newValue, found := cache.Get("new-key") + if !found || newValue != "new-value" { + t.Error("Expected to be able to set new values after Close() on managed wrapper") + } +} + +// TestCacheInterfaceWrapper_StandaloneClose tests that standalone cache wrappers +// properly close the underlying cache when Close() is called. +func TestCacheInterfaceWrapper_StandaloneClose(t *testing.T) { + // Create a standalone cache (not from the global cache manager) + standaloneCache := NewCache() + + wrapper, ok := standaloneCache.(*CacheInterfaceWrapper) + if !ok { + t.Fatal("Expected CacheInterfaceWrapper") + } + + // Verify it's NOT marked as managed + if wrapper.managed { + t.Error("Expected standalone cache wrapper to NOT be marked as managed") + } + + // Set some data + standaloneCache.Set("test-key", "test-value", time.Hour) + + // Get baseline goroutine count + baselineGoroutines := runtime.NumGoroutine() + + // Close the wrapper (should actually close the underlying cache) + wrapper.Close() + + // Give cleanup goroutine time to stop + time.Sleep(100 * time.Millisecond) + + // Goroutine count should decrease (cleanup routine stopped) + finalGoroutines := runtime.NumGoroutine() + if finalGoroutines > baselineGoroutines { + // This is acceptable - other tests might have started goroutines + t.Logf("Goroutine count: baseline=%d, final=%d", baselineGoroutines, finalGoroutines) + } +} + +// TestCacheInterfaceWrapper_MultipleInstancesClose_Regression tests that multiple +// plugin instances can close their cache wrappers without affecting shared caches. +// This is a regression test for issue #105. +func TestCacheInterfaceWrapper_MultipleInstancesClose_Regression(t *testing.T) { + cm := getTestCacheManager(t) + + // Simulate multiple plugin instances getting cache references + instances := make([]*CacheInterfaceWrapper, 5) + for i := 0; i < 5; i++ { + cache := cm.GetSharedTokenBlacklist() + wrapper, ok := cache.(*CacheInterfaceWrapper) + if !ok { + t.Fatal("Expected CacheInterfaceWrapper") + } + instances[i] = wrapper + + // Each instance might set some data + cache.Set(fmt.Sprintf("instance-%d-key", i), fmt.Sprintf("value-%d", i), time.Hour) + } + + // Close all instances (simulating plugin shutdown/reload) + for _, wrapper := range instances { + wrapper.Close() + } + + // The shared cache should still work after all instances closed their wrappers + newCache := cm.GetSharedTokenBlacklist() + + // Data set by earlier instances should still be accessible + for i := 0; i < 5; i++ { + key := fmt.Sprintf("instance-%d-key", i) + value, found := newCache.Get(key) + if !found { + t.Errorf("Expected data from instance %d to still be accessible", i) + } + expectedValue := fmt.Sprintf("value-%d", i) + if value != expectedValue { + t.Errorf("Expected '%s', got '%v'", expectedValue, value) + } + } + + // Should be able to add new data + newCache.Set("after-close-key", "after-close-value", time.Hour) + value, found := newCache.Get("after-close-key") + if !found || value != "after-close-value" { + t.Error("Expected to be able to use cache after all wrapper Close() calls") + } +} + +// TestAllSharedCachesMarkedAsManaged verifies all shared cache getters +// return managed wrappers to prevent the log flooding issue. +func TestAllSharedCachesMarkedAsManaged(t *testing.T) { + cm := getTestCacheManager(t) + + tests := []struct { + name string + cache CacheInterface + }{ + {"TokenBlacklist", cm.GetSharedTokenBlacklist()}, + {"IntrospectionCache", cm.GetSharedIntrospectionCache()}, + {"TokenTypeCache", cm.GetSharedTokenTypeCache()}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wrapper, ok := tt.cache.(*CacheInterfaceWrapper) + if !ok { + t.Fatalf("Expected CacheInterfaceWrapper for %s", tt.name) + } + if !wrapper.managed { + t.Errorf("%s cache wrapper should be marked as managed", tt.name) + } + }) + } +} + func TestCacheInterfaceWrapper_GetStats(t *testing.T) { cm := getTestCacheManager(t) cache := cm.GetSharedTokenBlacklist() diff --git a/internal/cache/backends/memory.go b/internal/cache/backends/memory.go index e66df44..e91a398 100644 --- a/internal/cache/backends/memory.go +++ b/internal/cache/backends/memory.go @@ -2,20 +2,27 @@ package backends import ( - "container/list" "context" "sync" "sync/atomic" "time" ) +// Default configuration values +const ( + defaultShardCount = 256 + defaultMaxSize = int64(10000) + defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB + defaultCleanupInterval = 5 * time.Minute +) + // memoryCacheItem represents an item in the memory cache type memoryCacheItem struct { expiresAt time.Time createdAt time.Time accessedAt time.Time value interface{} - element *list.Element + element interface{} // *list.Element, using interface{} to avoid import cycle key string accessCount int64 size int64 @@ -29,56 +36,89 @@ func (item *memoryCacheItem) isExpired() bool { return time.Now().After(item.expiresAt) } -// MemoryCacheBackend implements the CacheBackend interface using in-memory storage +// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage +// The sharded design reduces lock contention by partitioning keys across multiple shards, +// each with its own lock. type MemoryCacheBackend struct { + shards []*cacheShard startTime time.Time lastErrorTime time.Time - items map[string]*memoryCacheItem - lruList *list.List - cleanupDone chan bool + cleanupDone chan struct{} cleanupTicker *time.Ticker - evictionPolicy string lastError string - currentMemory int64 - misses atomic.Int64 - deletes atomic.Int64 - evictions atomic.Int64 - errors atomic.Int64 - totalGetTime atomic.Int64 - totalSetTime atomic.Int64 - getCount atomic.Int64 - setCount atomic.Int64 - sets atomic.Int64 - hits atomic.Int64 + shardCount uint32 + shardMask uint32 maxSize int64 - currentSize int64 maxMemory int64 cleanupInterval time.Duration - mu sync.RWMutex - closed atomic.Bool + + // Global stats (aggregated from shards) + hits atomic.Int64 + misses atomic.Int64 + sets atomic.Int64 + deletes atomic.Int64 + evictions atomic.Int64 + errors atomic.Int64 + + // Latency tracking + totalGetTime atomic.Int64 + totalSetTime atomic.Int64 + getCount atomic.Int64 + setCount atomic.Int64 + + // State + closed atomic.Bool + mu sync.RWMutex // For global operations like stats and error tracking } -// NewMemoryCacheBackend creates a new memory cache backend +// NewMemoryCacheBackend creates a new sharded memory cache backend func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend { if maxSize <= 0 { - maxSize = 10000 // Default to 10k items + maxSize = defaultMaxSize } if maxMemory <= 0 { - maxMemory = 100 * 1024 * 1024 // Default to 100MB + maxMemory = defaultMaxMemory } if cleanupInterval <= 0 { - cleanupInterval = 5 * time.Minute + cleanupInterval = defaultCleanupInterval + } + + shardCount := uint32(defaultShardCount) + + // For very small caches, reduce shard count to maintain sensible per-shard limits + // Ensure each shard can hold at least 2 items for proper LRU behavior + for shardCount > 1 && maxSize/int64(shardCount) < 2 { + shardCount /= 2 + } + if shardCount < 1 { + shardCount = 1 + } + + // Per-shard limits are soft hints; global limits are enforced + // Give shards 2x the average to allow for uneven distribution + shardMaxSize := (maxSize * 2) / int64(shardCount) + if shardMaxSize < 4 { + shardMaxSize = 4 + } + shardMaxMemory := (maxMemory * 2) / int64(shardCount) + if shardMaxMemory < 4096 { + shardMaxMemory = 4096 // Minimum 4KB per shard } m := &MemoryCacheBackend{ - items: make(map[string]*memoryCacheItem), - lruList: list.New(), + shards: make([]*cacheShard, shardCount), + shardCount: shardCount, + shardMask: shardCount - 1, // For fast modulo with power-of-2 maxSize: maxSize, maxMemory: maxMemory, startTime: time.Now(), cleanupInterval: cleanupInterval, - evictionPolicy: "lru", - cleanupDone: make(chan bool), + cleanupDone: make(chan struct{}), + } + + // Initialize shards + for i := uint32(0); i < shardCount; i++ { + m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory) } // Start cleanup goroutine @@ -88,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time. return m } +// getShard returns the shard for a given key +func (m *MemoryCacheBackend) getShard(key string) *cacheShard { + hash := fnv32(key) + return m.shards[hash&m.shardMask] +} + // cleanupLoop runs periodic cleanup of expired items func (m *MemoryCacheBackend) cleanupLoop() { for { @@ -100,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() { } } -// cleanupExpired removes all expired items from the cache +// cleanupExpired removes all expired items from all shards func (m *MemoryCacheBackend) cleanupExpired() { - m.mu.Lock() - defer m.mu.Unlock() - - var keysToDelete []string - for key, item := range m.items { - if item.isExpired() { - keysToDelete = append(keysToDelete, key) - } + if m.closed.Load() { + return } - for _, key := range keysToDelete { - m.deleteItemLocked(key) + totalRemoved := 0 + for _, shard := range m.shards { + totalRemoved += shard.cleanup() + } + + if totalRemoved > 0 { + m.evictions.Add(int64(totalRemoved)) } } @@ -130,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, m.getCount.Add(1) }() - m.mu.RLock() - item, exists := m.items[key] - m.mu.RUnlock() + shard := m.getShard(key) + value, exists, expired := shard.get(key) + + if expired { + // Clean up expired item + shard.delete(key) + m.misses.Add(1) + return nil, ErrCacheMiss + } if !exists { m.misses.Add(1) return nil, ErrCacheMiss } - if item.isExpired() { - m.mu.Lock() - m.deleteItemLocked(key) - m.mu.Unlock() - m.misses.Add(1) - return nil, ErrCacheMiss - } - - // Update access time and count - m.mu.Lock() - item.accessedAt = time.Now() - item.accessCount++ - // Move to front of LRU list - if m.evictionPolicy == "lru" && item.element != nil { - m.lruList.MoveToFront(item.element) - } - m.mu.Unlock() - m.hits.Add(1) - return item.value, nil + return value, nil } // Set stores a value in the cache with optional TTL @@ -174,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac m.setCount.Add(1) }() - // Calculate item size (simplified estimation) + // Calculate item size itemSize := int64(len(key)) + estimateValueSize(value) - m.mu.Lock() - defer m.mu.Unlock() + // Enforce global limits before adding new item + m.enforceGlobalLimits(itemSize) - // Check if we need to evict items - if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory { - m.evictLocked() - } - - // Check if key exists - if oldItem, exists := m.items[key]; exists { - m.currentMemory -= oldItem.size - if oldItem.element != nil { - m.lruList.Remove(oldItem.element) - } - } else { - m.currentSize++ - } - - now := time.Now() var expiresAt time.Time if ttl > 0 { - expiresAt = now.Add(ttl) + expiresAt = time.Now().Add(ttl) } - item := &memoryCacheItem{ - key: key, - value: value, - expiresAt: expiresAt, - createdAt: now, - accessedAt: now, - accessCount: 0, - size: itemSize, - } + shard := m.getShard(key) + shard.set(key, value, expiresAt, itemSize) - // Add to LRU list - if m.evictionPolicy == "lru" { - item.element = m.lruList.PushFront(item) - } - - m.items[key] = item - m.currentMemory += itemSize m.sets.Add(1) - return nil } +// enforceGlobalLimits ensures global size and memory limits are respected +// by evicting from shards when necessary +func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) { + // Check and enforce size limit + for { + totalSize, totalMemory := m.getGlobalStats() + + needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize + needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory + + if !needsSizeEviction && !needsMemoryEviction { + break + } + + // Find the shard with the most items and evict from it + evicted := m.evictFromLargestShard() + if !evicted { + break // No more items to evict + } + m.evictions.Add(1) + } +} + +// getGlobalStats returns the total size and memory usage across all shards +func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) { + for _, shard := range m.shards { + size, memory := shard.stats() + totalSize += size + totalMemory += memory + } + return +} + +// evictFromLargestShard evicts the globally oldest item across all shards +// This provides true LRU behavior even with sharding +func (m *MemoryCacheBackend) evictFromLargestShard() bool { + var oldestShard *cacheShard + var oldestTime time.Time + + for _, shard := range m.shards { + accessTime := shard.getOldestAccessTime() + // Skip empty shards + if accessTime.IsZero() { + continue + } + // Find the shard with the oldest (earliest) access time + if oldestShard == nil || accessTime.Before(oldestTime) { + oldestTime = accessTime + oldestShard = shard + } + } + + if oldestShard == nil { + return false + } + + return oldestShard.evictOne() +} + // Delete removes a key from the cache func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error { if m.closed.Load() { return ErrBackendUnavailable } - m.mu.Lock() - defer m.mu.Unlock() - - if _, exists := m.items[key]; !exists { - return nil + shard := m.getShard(key) + if shard.delete(key) { + m.deletes.Add(1) } - m.deleteItemLocked(key) - m.deletes.Add(1) return nil } -// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held) -func (m *MemoryCacheBackend) deleteItemLocked(key string) { - if item, exists := m.items[key]; exists { - m.currentMemory -= item.size - m.currentSize-- - if item.element != nil { - m.lruList.Remove(item.element) - } - delete(m.items, key) - } -} - -// evictLocked evicts items based on the eviction policy (must be called with lock held) -func (m *MemoryCacheBackend) evictLocked() { - if m.evictionPolicy == "lru" && m.lruList.Len() > 0 { - // Evict least recently used item - element := m.lruList.Back() - if element != nil { - item := element.Value.(*memoryCacheItem) - m.deleteItemLocked(item.key) - m.evictions.Add(1) - } - } -} - // Exists checks if a key exists in the cache func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) { if m.closed.Load() { return false, ErrBackendUnavailable } - m.mu.RLock() - item, exists := m.items[key] - m.mu.RUnlock() - - if !exists { - return false, nil - } - - return !item.isExpired(), nil + shard := m.getShard(key) + return shard.exists(key), nil } // Clear removes all items from the cache @@ -289,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error { return ErrBackendUnavailable } - m.mu.Lock() - defer m.mu.Unlock() - - m.items = make(map[string]*memoryCacheItem) - m.lruList = list.New() - m.currentSize = 0 - m.currentMemory = 0 + for _, shard := range m.shards { + shard.clear() + } return nil } @@ -306,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string return nil, ErrBackendUnavailable } - m.mu.RLock() - defer m.mu.RUnlock() - - var keys []string - for key, item := range m.items { - if !item.isExpired() && matchPattern(pattern, key) { - keys = append(keys, key) - } + var allKeys []string + for _, shard := range m.shards { + keys := shard.keys(pattern) + allKeys = append(allKeys, keys...) } - return keys, nil + return allKeys, nil } -// Size returns the number of items in the cache +// Size returns the total number of items in the cache func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) { if m.closed.Load() { return 0, ErrBackendUnavailable } - m.mu.RLock() - defer m.mu.RUnlock() + var total int64 + for _, shard := range m.shards { + size, _ := shard.stats() + total += size + } - return m.currentSize, nil + return total, nil } // TTL returns the remaining time-to-live for a key @@ -337,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration return 0, ErrBackendUnavailable } - m.mu.RLock() - item, exists := m.items[key] - m.mu.RUnlock() - - if !exists || item.isExpired() { + shard := m.getShard(key) + ttl, exists := shard.ttl(key) + if !exists { return 0, ErrCacheMiss } - if item.expiresAt.IsZero() { - return 0, nil // No expiration - } - - remaining := time.Until(item.expiresAt) - if remaining < 0 { - return 0, nil - } - - return remaining, nil + return ttl, nil } // Expire updates the TTL for an existing key @@ -363,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du return ErrBackendUnavailable } - m.mu.Lock() - defer m.mu.Unlock() - - item, exists := m.items[key] - if !exists || item.isExpired() { + shard := m.getShard(key) + if !shard.expire(key, ttl) { return ErrCacheMiss } - if ttl > 0 { - item.expiresAt = time.Now().Add(ttl) - } else { - item.expiresAt = time.Time{} // Remove expiration - } - return nil } @@ -386,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error return nil, ErrBackendUnavailable } + // Aggregate stats from all shards + var totalSize, totalMemory int64 + for _, shard := range m.shards { + size, memory := shard.stats() + totalSize += size + totalMemory += memory + } + m.mu.RLock() lastError := m.lastError lastErrorTime := m.lastErrorTime @@ -409,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error Deletes: m.deletes.Load(), Errors: m.errors.Load(), Evictions: m.evictions.Load(), - CurrentSize: m.currentSize, + CurrentSize: totalSize, MaxSize: m.maxSize, - MemoryUsage: m.currentMemory, + MemoryUsage: totalMemory, AverageGetLatency: avgGetLatency, AverageSetLatency: avgSetLatency, LastError: lastError, @@ -438,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error { m.cleanupTicker.Stop() close(m.cleanupDone) - m.mu.Lock() - m.items = nil - m.lruList = nil - m.mu.Unlock() + // Clear all shards + for _, shard := range m.shards { + shard.clear() + } return nil } @@ -474,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities { } } +// GetShardCount returns the number of shards (for testing/monitoring) +func (m *MemoryCacheBackend) GetShardCount() uint32 { + return m.shardCount +} + +// GetShardStats returns per-shard statistics (for monitoring) +func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 { + stats := make([]map[string]int64, m.shardCount) + for i, shard := range m.shards { + size, memory := shard.stats() + stats[i] = map[string]int64{ + "size": size, + "memory": memory, + } + } + return stats +} + // Helper functions // estimateValueSize estimates the size of a value in bytes func estimateValueSize(value interface{}) int64 { - // This is a simplified estimation - // In production, you might want to use a more accurate method switch v := value.(type) { case string: return int64(len(v)) @@ -502,7 +526,10 @@ func matchPattern(pattern, key string) bool { if pattern == "*" { return true } - // Simplified pattern matching - in production, use a proper glob library - return key == pattern || (len(pattern) > 0 && pattern[0] == '*' && - len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:]) + // Simplified pattern matching + if len(pattern) > 0 && pattern[0] == '*' { + suffix := pattern[1:] + return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix + } + return key == pattern } diff --git a/internal/cache/backends/memory_shard.go b/internal/cache/backends/memory_shard.go new file mode 100644 index 0000000..3feaf64 --- /dev/null +++ b/internal/cache/backends/memory_shard.go @@ -0,0 +1,290 @@ +package backends + +import ( + "container/list" + "sync" + "time" +) + +// cacheShard represents a single shard of the sharded cache +// Each shard has its own lock for reduced contention +type cacheShard struct { + items map[string]*memoryCacheItem + lruList *list.List + mu sync.RWMutex + maxSize int64 + maxMemory int64 + size int64 + memoryUsed int64 +} + +// newCacheShard creates a new cache shard +func newCacheShard(maxSize, maxMemory int64) *cacheShard { + return &cacheShard{ + items: make(map[string]*memoryCacheItem), + lruList: list.New(), + maxSize: maxSize, + maxMemory: maxMemory, + } +} + +// get retrieves a value from this shard +// Returns: value, exists, expired +func (s *cacheShard) get(key string) (interface{}, bool, bool) { + s.mu.RLock() + item, exists := s.items[key] + s.mu.RUnlock() + + if !exists { + return nil, false, false + } + + if item.isExpired() { + return nil, true, true // exists but expired + } + + // Update access time and LRU position under write lock + s.mu.Lock() + // Re-check item exists (could have been deleted) + item, exists = s.items[key] + if exists && !item.isExpired() { + item.accessedAt = time.Now() + item.accessCount++ + if elem, ok := item.element.(*list.Element); ok && elem != nil { + s.lruList.MoveToFront(elem) + } + } + s.mu.Unlock() + + if !exists || item.isExpired() { + return nil, false, false + } + + return item.value, true, false +} + +// set stores a value in this shard +func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if we need to evict items + if s.maxSize > 0 && s.size >= s.maxSize { + s.evictLRULocked() + } + if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory { + s.evictLRULocked() + } + + // Remove old item if exists + if oldItem, exists := s.items[key]; exists { + s.memoryUsed -= oldItem.size + if elem, ok := oldItem.element.(*list.Element); ok && elem != nil { + s.lruList.Remove(elem) + } + s.size-- + } + + now := time.Now() + item := &memoryCacheItem{ + key: key, + value: value, + expiresAt: expiresAt, + createdAt: now, + accessedAt: now, + accessCount: 0, + size: size, + } + + item.element = s.lruList.PushFront(item) + s.items[key] = item + s.size++ + s.memoryUsed += size +} + +// delete removes a key from this shard +// Returns true if the key was deleted +func (s *cacheShard) delete(key string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + item, exists := s.items[key] + if !exists { + return false + } + + s.deleteItemLocked(item) + return true +} + +// exists checks if a key exists (and is not expired) +func (s *cacheShard) exists(key string) bool { + s.mu.RLock() + item, exists := s.items[key] + s.mu.RUnlock() + + if !exists { + return false + } + + return !item.isExpired() +} + +// ttl returns the remaining TTL for a key +func (s *cacheShard) ttl(key string) (time.Duration, bool) { + s.mu.RLock() + item, exists := s.items[key] + s.mu.RUnlock() + + if !exists || item.isExpired() { + return 0, false + } + + if item.expiresAt.IsZero() { + return 0, true // No expiration + } + + remaining := time.Until(item.expiresAt) + if remaining < 0 { + return 0, false + } + + return remaining, true +} + +// expire updates the TTL for an existing key +func (s *cacheShard) expire(key string, ttl time.Duration) bool { + s.mu.Lock() + defer s.mu.Unlock() + + item, exists := s.items[key] + if !exists || item.isExpired() { + return false + } + + if ttl > 0 { + item.expiresAt = time.Now().Add(ttl) + } else { + item.expiresAt = time.Time{} // Remove expiration + } + + return true +} + +// keys returns all non-expired keys matching the pattern +func (s *cacheShard) keys(pattern string) []string { + s.mu.RLock() + defer s.mu.RUnlock() + + var keys []string + for key, item := range s.items { + if !item.isExpired() && matchPattern(pattern, key) { + keys = append(keys, key) + } + } + return keys +} + +// clear removes all items from this shard +func (s *cacheShard) clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.items = make(map[string]*memoryCacheItem) + s.lruList.Init() + s.size = 0 + s.memoryUsed = 0 +} + +// cleanup removes expired items +// Returns the number of items removed +func (s *cacheShard) cleanup() int { + s.mu.Lock() + defer s.mu.Unlock() + + var toRemove []*memoryCacheItem + for _, item := range s.items { + if item.isExpired() { + toRemove = append(toRemove, item) + } + } + + for _, item := range toRemove { + s.deleteItemLocked(item) + } + + return len(toRemove) +} + +// stats returns statistics for this shard +func (s *cacheShard) stats() (size, memory int64) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.size, s.memoryUsed +} + +// deleteItemLocked removes an item (must be called with lock held) +func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) { + if elem, ok := item.element.(*list.Element); ok && elem != nil { + s.lruList.Remove(elem) + } + delete(s.items, item.key) + s.size-- + s.memoryUsed -= item.size +} + +// evictLRULocked evicts the least recently used item (must be called with lock held) +func (s *cacheShard) evictLRULocked() bool { + if s.lruList.Len() == 0 { + return false + } + + element := s.lruList.Back() + if element != nil { + item := element.Value.(*memoryCacheItem) + s.deleteItemLocked(item) + return true + } + return false +} + +// evictOne evicts one item from this shard (for global limit enforcement) +func (s *cacheShard) evictOne() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.evictLRULocked() +} + +// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard +// Returns zero time if shard is empty +func (s *cacheShard) getOldestAccessTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.lruList.Len() == 0 { + return time.Time{} + } + + element := s.lruList.Back() + if element != nil { + item := element.Value.(*memoryCacheItem) + return item.accessedAt + } + return time.Time{} +} + +// fnv32 computes FNV-1a hash of a string +// This is a fast, well-distributed hash function +func fnv32(key string) uint32 { + const ( + offset32 = uint32(2166136261) + prime32 = uint32(16777619) + ) + + hash := offset32 + for i := 0; i < len(key); i++ { + hash ^= uint32(key[i]) + hash *= prime32 + } + return hash +} diff --git a/internal/cache/backends/memory_shard_test.go b/internal/cache/backends/memory_shard_test.go new file mode 100644 index 0000000..1cf2cec --- /dev/null +++ b/internal/cache/backends/memory_shard_test.go @@ -0,0 +1,283 @@ +package backends + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestShardedCache_ShardDistribution tests that keys are distributed across shards +func TestShardedCache_ShardDistribution(t *testing.T) { + t.Parallel() + + // Create a cache with large enough size to have multiple shards + config := DefaultConfig() + config.MaxSize = 10000 + config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB + + backend, err := NewMemoryBackend(config) + require.NoError(t, err) + defer backend.Close() + + ctx := context.Background() + + // Add many items to see distribution + numItems := 1000 + for i := 0; i < numItems; i++ { + key := fmt.Sprintf("dist-key-%d", i) + value := []byte(fmt.Sprintf("dist-value-%d", i)) + err := backend.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + } + + // Check that items are distributed across multiple shards + shardStats := backend.MemoryCacheBackend.GetShardStats() + nonEmptyShards := 0 + for _, stat := range shardStats { + if stat["size"] > 0 { + nonEmptyShards++ + } + } + + // With good hash distribution, we should have items in multiple shards + assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards") +} + +// TestShardedCache_ShardCount tests that shard count adapts to cache size +func TestShardedCache_ShardCount(t *testing.T) { + t.Parallel() + + tests := []struct { + maxSize int + expectLowShards bool + }{ + {5, true}, // Very small cache should have fewer shards + {100, true}, // Small cache should have fewer shards + {10000, false}, // Large cache should have default shards + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) { + config := DefaultConfig() + config.MaxSize = tt.maxSize + + backend, err := NewMemoryBackend(config) + require.NoError(t, err) + defer backend.Close() + + shardCount := backend.MemoryCacheBackend.GetShardCount() + + if tt.expectLowShards { + assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards") + } else { + assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count") + } + }) + } +} + +// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key +func TestShardedCache_ConcurrentSameKey(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + ctx := context.Background() + key := "concurrent-same-key" + initialValue := []byte("initial-value") + + err = backend.Set(ctx, key, initialValue, time.Minute) + require.NoError(t, err) + + var wg sync.WaitGroup + goroutines := 50 + iterations := 100 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < iterations; j++ { + // Mix of reads and writes + if j%3 == 0 { + newValue := []byte(fmt.Sprintf("value-%d-%d", id, j)) + err := backend.Set(ctx, key, newValue, time.Minute) + assert.NoError(t, err) + } else { + _, _, _, err := backend.Get(ctx, key) + assert.NoError(t, err) + } + } + }(i) + } + + wg.Wait() + + // Key should still exist + exists, err := backend.Exists(ctx, key) + require.NoError(t, err) + assert.True(t, exists) +} + +// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained +func TestShardedCache_GlobalLRUEviction(t *testing.T) { + t.Parallel() + + // Create a small cache to force eviction + config := DefaultConfig() + config.MaxSize = 10 + + backend, err := NewMemoryBackend(config) + require.NoError(t, err) + defer backend.Close() + + ctx := context.Background() + + // Add items + for i := 0; i < 10; i++ { + key := fmt.Sprintf("global-lru-%d", i) + value := []byte(fmt.Sprintf("value-%d", i)) + err := backend.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + // Small delay to ensure different access times + time.Sleep(time.Millisecond) + } + + // Access some items to make them recently used + for i := 5; i < 10; i++ { + key := fmt.Sprintf("global-lru-%d", i) + _, _, _, err := backend.Get(ctx, key) + require.NoError(t, err) + } + + // Add more items to trigger eviction + for i := 10; i < 15; i++ { + key := fmt.Sprintf("global-lru-%d", i) + value := []byte(fmt.Sprintf("value-%d", i)) + err := backend.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + } + + // Recently accessed items (5-9) should still exist + for i := 5; i < 10; i++ { + key := fmt.Sprintf("global-lru-%d", i) + exists, err := backend.Exists(ctx, key) + require.NoError(t, err) + assert.True(t, exists, "Recently accessed item %d should exist", i) + } + + // Check eviction stats + stats := backend.GetStats() + evictions := stats["evictions"].(int64) + assert.Greater(t, evictions, int64(0), "Should have evictions") +} + +// TestShardedCache_StatsAggregation tests that stats are aggregated correctly +func TestShardedCache_StatsAggregation(t *testing.T) { + t.Parallel() + + config := DefaultConfig() + config.MaxSize = 10000 + + backend, err := NewMemoryBackend(config) + require.NoError(t, err) + defer backend.Close() + + ctx := context.Background() + + // Add items to multiple shards + numItems := 100 + for i := 0; i < numItems; i++ { + key := fmt.Sprintf("stats-key-%d", i) + value := []byte(fmt.Sprintf("stats-value-%d", i)) + err := backend.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + } + + // Read some items + for i := 0; i < numItems/2; i++ { + key := fmt.Sprintf("stats-key-%d", i) + backend.Get(ctx, key) + } + + // Read non-existent items + for i := 0; i < 10; i++ { + backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i)) + } + + stats := backend.GetStats() + + // Verify stats + assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match") + assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match") + assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match") + assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match") + + // Verify hit rate + hitRate := stats["hit_rate"].(float64) + expectedHitRate := float64(numItems/2) / float64(numItems/2+10) + assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match") +} + +// BenchmarkShardedCache_Parallel benchmarks parallel access +func BenchmarkShardedCache_Parallel(b *testing.B) { + config := DefaultConfig() + config.MaxSize = 100000 + config.MaxMemoryBytes = 100 * 1024 * 1024 + + backend, _ := NewMemoryBackend(config) + defer backend.Close() + + ctx := context.Background() + + // Pre-populate cache + for i := 0; i < 10000; i++ { + key := fmt.Sprintf("bench-key-%d", i) + value := []byte(fmt.Sprintf("bench-value-%d", i)) + backend.Set(ctx, key, value, time.Hour) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("bench-key-%d", i%10000) + backend.Get(ctx, key) + i++ + } + }) +} + +// BenchmarkShardedCache_MixedOps benchmarks mixed operations +func BenchmarkShardedCache_MixedOps(b *testing.B) { + config := DefaultConfig() + config.MaxSize = 100000 + config.MaxMemoryBytes = 100 * 1024 * 1024 + + backend, _ := NewMemoryBackend(config) + defer backend.Close() + + ctx := context.Background() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("mixed-key-%d", i%1000) + if i%3 == 0 { + value := []byte(fmt.Sprintf("mixed-value-%d", i)) + backend.Set(ctx, key, value, time.Hour) + } else { + backend.Get(ctx, key) + } + i++ + } + }) +} diff --git a/internal/cache/backends/memory_wrapper.go b/internal/cache/backends/memory_wrapper.go index 7528855..0018bab 100644 --- a/internal/cache/backends/memory_wrapper.go +++ b/internal/cache/backends/memory_wrapper.go @@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat return nil, 0, false, err } - // Get the item directly to check TTL - m.MemoryCacheBackend.mu.RLock() - item, exists := m.MemoryCacheBackend.items[key] - m.MemoryCacheBackend.mu.RUnlock() - - if !exists { - return nil, 0, false, nil - } - - var ttl time.Duration - if !item.expiresAt.IsZero() { - ttl = time.Until(item.expiresAt) - if ttl < 0 { - ttl = 0 - } + // Get TTL using the TTL method + ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key) + if ttlErr != nil { + // If we can't get TTL, still return the value with 0 TTL + ttl = 0 } // Convert interface{} to []byte @@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat if bytes, ok := val.([]byte); ok { valueBytes = bytes } else { - // If it's not already []byte, we might need to handle other types - // For now, we'll just return an error + // If it's not already []byte, return an error return nil, 0, false, ErrInvalidValue } } @@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} { } return map[string]interface{}{ - "type": stats.Type, - "hits": stats.Hits, - "misses": stats.Misses, - "sets": stats.Sets, - "deletes": stats.Deletes, - "errors": stats.Errors, - "evictions": stats.Evictions, - "size": stats.CurrentSize, - "max_size": stats.MaxSize, - "memory": stats.MemoryUsage, - "hit_rate": hitRate, - "uptime": stats.Uptime, - "start_time": stats.StartTime, + "type": stats.Type, + "hits": stats.Hits, + "misses": stats.Misses, + "sets": stats.Sets, + "deletes": stats.Deletes, + "errors": stats.Errors, + "evictions": stats.Evictions, + "size": stats.CurrentSize, + "max_size": stats.MaxSize, + "memory": stats.MemoryUsage, + "hit_rate": hitRate, + "uptime": stats.Uptime, + "start_time": stats.StartTime, + "shard_count": m.MemoryCacheBackend.GetShardCount(), } } diff --git a/internal/cache/backends/redis.go b/internal/cache/backends/redis.go index 0d32df9..cc80ced 100644 --- a/internal/cache/backends/redis.go +++ b/internal/cache/backends/redis.go @@ -431,39 +431,135 @@ func isRetryableError(err error) bool { return false } -// SetMany stores multiple values in Redis (batch operation) +// SetMany stores multiple values in Redis using pipelining for efficiency +// This reduces N round-trips to a single round-trip func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error { if r.closed.Load() { return ErrBackendClosed } - // For simplicity, execute sequentially (can be optimized with pipelining later) - for key, value := range items { - if err := r.Set(ctx, key, value, ttl); err != nil { - return err + if len(items) == 0 { + return nil + } + + // For single items, use regular Set + if len(items) == 1 { + for key, value := range items { + return r.Set(ctx, key, value, ttl) } } + conn, err := r.pool.Get(ctx) + if err != nil { + return err + } + defer r.pool.Put(conn) + + pipeline := conn.NewPipeline() + + // Queue all SET commands + ttlSeconds := int(ttl.Seconds()) + ttlMillis := ttl.Milliseconds() + + for key, value := range items { + prefixedKey := r.prefixKey(key) + + if ttl > 0 { + if ttlMillis < 1000 { + // Use PSETEX for sub-second TTLs + pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value)) + } else { + // Use SETEX for larger TTLs + pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value)) + } + } else { + pipeline.Queue("SET", prefixedKey, string(value)) + } + } + + // Execute pipeline + responses, err := pipeline.Execute() + if err != nil { + return fmt.Errorf("pipeline SetMany failed: %w", err) + } + + // Check responses for errors (each should be "OK") + for i, resp := range responses { + if resp == nil { + continue + } + if str, ok := resp.(string); ok && str == "OK" { + continue + } + return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp) + } + return nil } -// GetMany retrieves multiple values from Redis +// GetMany retrieves multiple values from Redis using pipelining for efficiency +// This reduces N round-trips to a single round-trip func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) { if r.closed.Load() { return nil, ErrBackendClosed } - result := make(map[string][]byte) + if len(keys) == 0 { + return make(map[string][]byte), nil + } - // For simplicity, execute sequentially - for _, key := range keys { - value, _, exists, err := r.Get(ctx, key) + // For single key, use regular Get + if len(keys) == 1 { + result := make(map[string][]byte) + value, _, exists, err := r.Get(ctx, keys[0]) if err != nil { return nil, err } if exists { - result[key] = value + result[keys[0]] = value } + return result, nil + } + + conn, err := r.pool.Get(ctx) + if err != nil { + return nil, err + } + defer r.pool.Put(conn) + + pipeline := conn.NewPipeline() + + // Queue all GET commands + prefixedKeys := make([]string, len(keys)) + for i, key := range keys { + prefixedKeys[i] = r.prefixKey(key) + pipeline.Queue("GET", prefixedKeys[i]) + } + + // Execute pipeline + responses, err := pipeline.Execute() + if err != nil { + return nil, fmt.Errorf("pipeline GetMany failed: %w", err) + } + + // Process responses + result := make(map[string][]byte) + for i, resp := range responses { + if resp == nil { + // Key doesn't exist + r.misses.Add(1) + continue + } + + value, err := RESPString(resp) + if err != nil { + // Invalid response, skip this key + r.misses.Add(1) + continue + } + + r.hits.Add(1) + result[keys[i]] = []byte(value) } return result, nil diff --git a/internal/cache/backends/redis_pipeline_test.go b/internal/cache/backends/redis_pipeline_test.go new file mode 100644 index 0000000..c33925e --- /dev/null +++ b/internal/cache/backends/redis_pipeline_test.go @@ -0,0 +1,461 @@ +package backends + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupTestRedis creates a miniredis instance for testing +func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) { + t.Helper() + + mr, err := miniredis.Run() + require.NoError(t, err) + + t.Cleanup(func() { + mr.Close() + }) + + backend, err := NewRedisBackend(&Config{ + RedisAddr: mr.Addr(), + RedisPrefix: "test:", + PoolSize: 5, + }) + require.NoError(t, err) + + t.Cleanup(func() { + backend.Close() + }) + + return mr, backend +} + +// TestPipeline_Basic tests basic pipeline functionality +func TestPipeline_Basic(t *testing.T) { + t.Parallel() + + mr, err := miniredis.Run() + require.NoError(t, err) + defer mr.Close() + + config := &PoolConfig{ + Address: mr.Addr(), + MaxConnections: 5, + ConnectTimeout: 5 * time.Second, + ReadTimeout: 1 * time.Second, + WriteTimeout: 1 * time.Second, + } + + pool, err := NewConnectionPool(config) + require.NoError(t, err) + defer pool.Close() + + ctx := context.Background() + + conn, err := pool.Get(ctx) + require.NoError(t, err) + defer pool.Put(conn) + + t.Run("SingleCommand", func(t *testing.T) { + pipeline := conn.NewPipeline() + pipeline.Queue("SET", "single-key", "single-value") + + responses, err := pipeline.Execute() + require.NoError(t, err) + require.Len(t, responses, 1) + assert.Equal(t, "OK", responses[0]) + }) + + t.Run("MultipleCommands", func(t *testing.T) { + pipeline := conn.NewPipeline() + pipeline.Queue("SET", "key1", "value1") + pipeline.Queue("SET", "key2", "value2") + pipeline.Queue("SET", "key3", "value3") + pipeline.Queue("GET", "key1") + pipeline.Queue("GET", "key2") + pipeline.Queue("GET", "key3") + + responses, err := pipeline.Execute() + require.NoError(t, err) + require.Len(t, responses, 6) + + // First 3 are SET responses + assert.Equal(t, "OK", responses[0]) + assert.Equal(t, "OK", responses[1]) + assert.Equal(t, "OK", responses[2]) + + // Last 3 are GET responses + assert.Equal(t, "value1", responses[3]) + assert.Equal(t, "value2", responses[4]) + assert.Equal(t, "value3", responses[5]) + }) + + t.Run("EmptyPipeline", func(t *testing.T) { + pipeline := conn.NewPipeline() + + responses, err := pipeline.Execute() + require.NoError(t, err) + assert.Nil(t, responses) + }) + + t.Run("NilResponses", func(t *testing.T) { + pipeline := conn.NewPipeline() + pipeline.Queue("GET", "nonexistent-key") + + responses, err := pipeline.Execute() + require.NoError(t, err) + require.Len(t, responses, 1) + assert.Nil(t, responses[0]) + }) +} + +// TestPipeline_SetMany tests pipelined SetMany +func TestPipeline_SetMany(t *testing.T) { + t.Parallel() + + _, backend := setupTestRedis(t) + ctx := context.Background() + + t.Run("SetManyItems", func(t *testing.T) { + items := make(map[string][]byte) + for i := 0; i < 10; i++ { + items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i)) + } + + err := backend.SetMany(ctx, items, time.Minute) + require.NoError(t, err) + + // Verify all items were set + for key, expectedValue := range items { + value, _, exists, err := backend.Get(ctx, key) + require.NoError(t, err) + assert.True(t, exists, "Key %s should exist", key) + assert.Equal(t, expectedValue, value) + } + }) + + t.Run("SetManyEmpty", func(t *testing.T) { + err := backend.SetMany(ctx, map[string][]byte{}, time.Minute) + require.NoError(t, err) + }) + + t.Run("SetManySingleItem", func(t *testing.T) { + items := map[string][]byte{ + "single-setmany": []byte("single-value"), + } + + err := backend.SetMany(ctx, items, time.Minute) + require.NoError(t, err) + + value, _, exists, err := backend.Get(ctx, "single-setmany") + require.NoError(t, err) + assert.True(t, exists) + assert.Equal(t, []byte("single-value"), value) + }) + + t.Run("SetManyNoTTL", func(t *testing.T) { + items := map[string][]byte{ + "nottl-key1": []byte("value1"), + "nottl-key2": []byte("value2"), + } + + err := backend.SetMany(ctx, items, 0) + require.NoError(t, err) + + // Keys should exist + for key := range items { + exists, err := backend.Exists(ctx, key) + require.NoError(t, err) + assert.True(t, exists) + } + }) +} + +// TestPipeline_GetMany tests pipelined GetMany +func TestPipeline_GetMany(t *testing.T) { + t.Parallel() + + _, backend := setupTestRedis(t) + ctx := context.Background() + + // Pre-populate cache + for i := 0; i < 10; i++ { + key := fmt.Sprintf("getmany-key-%d", i) + value := []byte(fmt.Sprintf("value-%d", i)) + err := backend.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + } + + t.Run("GetManyExisting", func(t *testing.T) { + keys := make([]string, 10) + for i := 0; i < 10; i++ { + keys[i] = fmt.Sprintf("getmany-key-%d", i) + } + + results, err := backend.GetMany(ctx, keys) + require.NoError(t, err) + assert.Len(t, results, 10) + + for i, key := range keys { + assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key]) + } + }) + + t.Run("GetManyMixed", func(t *testing.T) { + keys := []string{ + "getmany-key-0", // exists + "nonexistent-key-1", // doesn't exist + "getmany-key-2", // exists + "nonexistent-key-2", // doesn't exist + } + + results, err := backend.GetMany(ctx, keys) + require.NoError(t, err) + assert.Len(t, results, 2) // Only existing keys + + assert.Equal(t, []byte("value-0"), results["getmany-key-0"]) + assert.Equal(t, []byte("value-2"), results["getmany-key-2"]) + assert.NotContains(t, results, "nonexistent-key-1") + assert.NotContains(t, results, "nonexistent-key-2") + }) + + t.Run("GetManyEmpty", func(t *testing.T) { + results, err := backend.GetMany(ctx, []string{}) + require.NoError(t, err) + assert.NotNil(t, results) + assert.Len(t, results, 0) + }) + + t.Run("GetManySingleKey", func(t *testing.T) { + results, err := backend.GetMany(ctx, []string{"getmany-key-5"}) + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, []byte("value-5"), results["getmany-key-5"]) + }) + + t.Run("GetManyAllNonexistent", func(t *testing.T) { + keys := []string{ + "nonexistent-1", + "nonexistent-2", + "nonexistent-3", + } + + results, err := backend.GetMany(ctx, keys) + require.NoError(t, err) + assert.Len(t, results, 0) + }) +} + +// TestPipeline_LargeBatch tests pipelining with large batches +func TestPipeline_LargeBatch(t *testing.T) { + t.Parallel() + + _, backend := setupTestRedis(t) + ctx := context.Background() + + t.Run("SetMany100Items", func(t *testing.T) { + items := make(map[string][]byte) + for i := 0; i < 100; i++ { + items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i)) + } + + err := backend.SetMany(ctx, items, time.Minute) + require.NoError(t, err) + + // Verify random samples + for _, i := range []int{0, 25, 50, 75, 99} { + key := fmt.Sprintf("large-batch-%d", i) + value, _, exists, err := backend.Get(ctx, key) + require.NoError(t, err) + assert.True(t, exists) + assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value) + } + }) + + t.Run("GetMany100Items", func(t *testing.T) { + keys := make([]string, 100) + for i := 0; i < 100; i++ { + keys[i] = fmt.Sprintf("large-batch-%d", i) + } + + results, err := backend.GetMany(ctx, keys) + require.NoError(t, err) + assert.Len(t, results, 100) + }) +} + +// TestPipeline_Stats tests that stats are tracked correctly with pipelining +func TestPipeline_Stats(t *testing.T) { + t.Parallel() + + _, backend := setupTestRedis(t) + ctx := context.Background() + + // Set some items + items := map[string][]byte{ + "stats-key-1": []byte("value1"), + "stats-key-2": []byte("value2"), + } + err := backend.SetMany(ctx, items, time.Minute) + require.NoError(t, err) + + // Get items (some exist, some don't) + keys := []string{ + "stats-key-1", + "stats-key-2", + "stats-key-nonexistent", + } + results, err := backend.GetMany(ctx, keys) + require.NoError(t, err) + assert.Len(t, results, 2) + + // Check stats + stats := backend.GetStats() + hits := stats["hits"].(int64) + misses := stats["misses"].(int64) + + assert.Equal(t, int64(2), hits, "Should have 2 hits") + assert.Equal(t, int64(1), misses, "Should have 1 miss") +} + +// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining +func BenchmarkPipeline_SetMany(b *testing.B) { + mr, err := miniredis.Run() + if err != nil { + b.Fatal(err) + } + defer mr.Close() + + backend, err := NewRedisBackend(&Config{ + RedisAddr: mr.Addr(), + RedisPrefix: "bench:", + PoolSize: 10, + }) + if err != nil { + b.Fatal(err) + } + defer backend.Close() + + ctx := context.Background() + + // Prepare items + items := make(map[string][]byte) + for i := 0; i < 100; i++ { + items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = backend.SetMany(ctx, items, time.Minute) + } +} + +// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining +func BenchmarkPipeline_GetMany(b *testing.B) { + mr, err := miniredis.Run() + if err != nil { + b.Fatal(err) + } + defer mr.Close() + + backend, err := NewRedisBackend(&Config{ + RedisAddr: mr.Addr(), + RedisPrefix: "bench:", + PoolSize: 10, + }) + if err != nil { + b.Fatal(err) + } + defer backend.Close() + + ctx := context.Background() + + // Pre-populate cache + for i := 0; i < 100; i++ { + key := fmt.Sprintf("bench-key-%d", i) + value := []byte(fmt.Sprintf("bench-value-%d", i)) + backend.Set(ctx, key, value, time.Hour) + } + + // Prepare keys + keys := make([]string, 100) + for i := 0; i < 100; i++ { + keys[i] = fmt.Sprintf("bench-key-%d", i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = backend.GetMany(ctx, keys) + } +} + +// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations +func BenchmarkPipeline_VsSequential(b *testing.B) { + mr, err := miniredis.Run() + if err != nil { + b.Fatal(err) + } + defer mr.Close() + + backend, err := NewRedisBackend(&Config{ + RedisAddr: mr.Addr(), + RedisPrefix: "bench:", + PoolSize: 10, + }) + if err != nil { + b.Fatal(err) + } + defer backend.Close() + + ctx := context.Background() + + // Prepare items + items := make(map[string][]byte) + keys := make([]string, 50) + for i := 0; i < 50; i++ { + key := fmt.Sprintf("compare-key-%d", i) + keys[i] = key + items[key] = []byte(fmt.Sprintf("compare-value-%d", i)) + } + + b.Run("Pipelined-Set", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = backend.SetMany(ctx, items, time.Minute) + } + }) + + b.Run("Sequential-Set", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for key, value := range items { + _ = backend.Set(ctx, key, value, time.Minute) + } + } + }) + + // Pre-populate for get benchmarks + _ = backend.SetMany(ctx, items, time.Hour) + + b.Run("Pipelined-Get", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = backend.GetMany(ctx, keys) + } + }) + + b.Run("Sequential-Get", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, key := range keys { + _, _, _, _ = backend.Get(ctx, key) + } + } + }) +} diff --git a/internal/cache/backends/redis_pool.go b/internal/cache/backends/redis_pool.go index 3037320..16b79d0 100644 --- a/internal/cache/backends/redis_pool.go +++ b/internal/cache/backends/redis_pool.go @@ -336,3 +336,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool { _, err := conn.Do("PING") return err == nil } + +// Pipeline represents a Redis pipeline for batch operations +// It queues multiple commands and executes them in a single round-trip +type Pipeline struct { + conn *RedisConn + commands []pipelineCommand + mu sync.Mutex +} + +// pipelineCommand represents a single command in the pipeline +type pipelineCommand struct { + command string + args []string +} + +// NewPipeline creates a new pipeline for the connection +func (c *RedisConn) NewPipeline() *Pipeline { + return &Pipeline{ + conn: c, + commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size + } +} + +// Queue adds a command to the pipeline +func (p *Pipeline) Queue(command string, args ...string) { + p.mu.Lock() + defer p.mu.Unlock() + + p.commands = append(p.commands, pipelineCommand{ + command: command, + args: args, + }) +} + +// Execute sends all queued commands and returns all responses +// Returns a slice of responses in the same order as commands were queued +func (p *Pipeline) Execute() ([]interface{}, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.commands) == 0 { + return nil, nil + } + + if p.conn.closed.Load() { + return nil, ErrBackendClosed + } + + p.conn.mu.Lock() + defer p.conn.mu.Unlock() + + // Set write timeout for all commands + if p.conn.writeTimeout > 0 { + // Use longer timeout for batch operations + timeout := p.conn.writeTimeout * time.Duration(len(p.commands)) + if timeout > 30*time.Second { + timeout = 30 * time.Second // Cap at 30 seconds + } + _ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout)) + } + + // Write all commands (pipelining - send all before reading any responses) + writer := NewRESPWriter(p.conn.conn) + for _, cmd := range p.commands { + cmdArgs := append([]string{cmd.command}, cmd.args...) + if err := writer.WriteCommand(cmdArgs...); err != nil { + writer.Release() + p.conn.closed.Store(true) + return nil, fmt.Errorf("pipeline write error: %w", err) + } + } + writer.Release() + + // Set read timeout for all responses + if p.conn.readTimeout > 0 { + timeout := p.conn.readTimeout * time.Duration(len(p.commands)) + if timeout > 30*time.Second { + timeout = 30 * time.Second + } + _ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout)) + } + + // Read all responses + responses := make([]interface{}, len(p.commands)) + reader := NewRESPReader(p.conn.conn) + defer reader.Release() + + for i := range p.commands { + resp, err := reader.ReadResponse() + if err != nil { + // For nil responses, store nil instead of erroring + if errors.Is(err, ErrNilResponse) { + responses[i] = nil + continue + } + p.conn.closed.Store(true) + return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err) + } + responses[i] = resp + } + + return responses, nil +} + +// Clear resets the pipeline for reuse +func (p *Pipeline) Clear() { + p.mu.Lock() + defer p.mu.Unlock() + p.commands = p.commands[:0] +} + +// Len returns the number of queued commands +func (p *Pipeline) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + return len(p.commands) +} diff --git a/internal/cache/backends/singleflight.go b/internal/cache/backends/singleflight.go new file mode 100644 index 0000000..34f0afd --- /dev/null +++ b/internal/cache/backends/singleflight.go @@ -0,0 +1,183 @@ +package backends + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +// SingleflightCache wraps a CacheBackend with singleflight deduplication +// to prevent thundering herd problems when multiple concurrent requests +// try to fetch the same uncached key. +type SingleflightCache struct { + backend CacheBackend + mu sync.Mutex + calls map[string]*singleflightCall + + // Metrics + deduplicatedCalls atomic.Int64 + totalCalls atomic.Int64 +} + +// singleflightCall represents an in-flight or completed fetch call +type singleflightCall struct { + wg sync.WaitGroup + val []byte + ttl time.Duration + err error + done bool +} + +// NewSingleflightCache creates a new singleflight-wrapped cache backend +func NewSingleflightCache(backend CacheBackend) *SingleflightCache { + return &SingleflightCache{ + backend: backend, + calls: make(map[string]*singleflightCall), + } +} + +// Fetcher is a function type that fetches data when cache misses +type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error) + +// GetOrFetch retrieves a value from cache or calls the fetcher exactly once +// per key when there's a cache miss. Concurrent calls for the same key will +// wait for the first call to complete and share its result. +func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) { + s.totalCalls.Add(1) + + // Try cache first + value, _, exists, err := s.backend.Get(ctx, key) + if err != nil { + return nil, err + } + if exists { + return value, nil + } + + // Cache miss - use singleflight + s.mu.Lock() + + // Check if there's already an in-flight call for this key + if call, ok := s.calls[key]; ok { + s.mu.Unlock() + s.deduplicatedCalls.Add(1) + + // Wait for the in-flight call to complete + call.wg.Wait() + + // Check context cancellation + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return call.val, call.err + } + + // Create new call + call := &singleflightCall{} + call.wg.Add(1) + s.calls[key] = call + s.mu.Unlock() + + // Execute the fetcher + call.val, call.ttl, call.err = fetcher(ctx) + call.done = true + + // If successful, store in cache + if call.err == nil && call.val != nil { + // Use a background context for cache storage to ensure it completes + // even if the original context is cancelled + storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = s.backend.Set(storeCtx, key, call.val, call.ttl) + cancel() + } + + // Signal waiting goroutines + call.wg.Done() + + // Clean up the call from the map after a short delay + // This allows late arrivals to still benefit from the result + go func() { + time.Sleep(100 * time.Millisecond) + s.mu.Lock() + if c, ok := s.calls[key]; ok && c == call { + delete(s.calls, key) + } + s.mu.Unlock() + }() + + return call.val, call.err +} + +// Get retrieves a value from the underlying cache backend +func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) { + return s.backend.Get(ctx, key) +} + +// Set stores a value in the underlying cache backend +func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error { + return s.backend.Set(ctx, key, value, ttl) +} + +// Delete removes a key from the underlying cache backend +func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) { + return s.backend.Delete(ctx, key) +} + +// Exists checks if a key exists in the underlying cache backend +func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) { + return s.backend.Exists(ctx, key) +} + +// Clear removes all keys from the underlying cache backend +func (s *SingleflightCache) Clear(ctx context.Context) error { + return s.backend.Clear(ctx) +} + +// GetStats returns cache statistics including singleflight metrics +func (s *SingleflightCache) GetStats() map[string]interface{} { + stats := s.backend.GetStats() + + // Add singleflight-specific stats + totalCalls := s.totalCalls.Load() + deduped := s.deduplicatedCalls.Load() + + stats["singleflight_total_calls"] = totalCalls + stats["singleflight_deduplicated"] = deduped + if totalCalls > 0 { + stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls) + } else { + stats["singleflight_dedup_rate"] = float64(0) + } + + s.mu.Lock() + stats["singleflight_inflight"] = len(s.calls) + s.mu.Unlock() + + return stats +} + +// Close shuts down the cache backend +func (s *SingleflightCache) Close() error { + return s.backend.Close() +} + +// Ping checks if the backend is healthy +func (s *SingleflightCache) Ping(ctx context.Context) error { + return s.backend.Ping(ctx) +} + +// GetBackend returns the underlying cache backend +func (s *SingleflightCache) GetBackend() CacheBackend { + return s.backend +} + +// ResetStats resets the singleflight statistics +func (s *SingleflightCache) ResetStats() { + s.totalCalls.Store(0) + s.deduplicatedCalls.Store(0) +} + +// Ensure SingleflightCache implements CacheBackend +var _ CacheBackend = (*SingleflightCache)(nil) diff --git a/internal/cache/backends/singleflight_test.go b/internal/cache/backends/singleflight_test.go new file mode 100644 index 0000000..b211c12 --- /dev/null +++ b/internal/cache/backends/singleflight_test.go @@ -0,0 +1,510 @@ +package backends + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality +func TestSingleflightCache_BasicGetOrFetch(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + + t.Run("CacheHit", func(t *testing.T) { + key := "existing-key" + value := []byte("existing-value") + + // Pre-populate cache + err := cache.Set(ctx, key, value, time.Minute) + require.NoError(t, err) + + var fetchCalled bool + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + fetchCalled = true + return []byte("fetched-value"), time.Minute, nil + } + + result, err := cache.GetOrFetch(ctx, key, fetcher) + require.NoError(t, err) + assert.Equal(t, value, result) + assert.False(t, fetchCalled, "Fetcher should not be called on cache hit") + }) + + t.Run("CacheMiss", func(t *testing.T) { + key := "missing-key" + expectedValue := []byte("fetched-value") + + var fetchCalled bool + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + fetchCalled = true + return expectedValue, time.Minute, nil + } + + result, err := cache.GetOrFetch(ctx, key, fetcher) + require.NoError(t, err) + assert.Equal(t, expectedValue, result) + assert.True(t, fetchCalled, "Fetcher should be called on cache miss") + + // Verify value was stored in cache + cached, _, exists, err := cache.Get(ctx, key) + require.NoError(t, err) + assert.True(t, exists) + assert.Equal(t, expectedValue, cached) + }) + + t.Run("FetcherError", func(t *testing.T) { + key := "error-key" + expectedErr := errors.New("fetch failed") + + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + return nil, 0, expectedErr + } + + result, err := cache.GetOrFetch(ctx, key, fetcher) + assert.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Nil(t, result) + + // Verify nothing was stored in cache + _, _, exists, err := cache.Get(ctx, key) + require.NoError(t, err) + assert.False(t, exists) + }) +} + +// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated +func TestSingleflightCache_Deduplication(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + key := "dedup-key" + expectedValue := []byte("dedup-value") + + var fetchCount atomic.Int32 + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + fetchCount.Add(1) + // Simulate slow fetch + time.Sleep(100 * time.Millisecond) + return expectedValue, time.Minute, nil + } + + // Launch multiple concurrent requests + concurrency := 10 + var wg sync.WaitGroup + results := make([][]byte, concurrency) + errs := make([]error, concurrency) + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher) + }(i) + } + + wg.Wait() + + // Verify all requests got the same result + for i := 0; i < concurrency; i++ { + assert.NoError(t, errs[i]) + assert.Equal(t, expectedValue, results[i]) + } + + // Verify fetcher was only called once + assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once") + + // Verify deduplication stats + stats := cache.GetStats() + deduped := stats["singleflight_deduplicated"].(int64) + assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls") +} + +// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel +func TestSingleflightCache_DifferentKeys(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + + var fetchCount atomic.Int32 + fetchStarted := make(chan struct{}, 3) + fetchComplete := make(chan struct{}) + + fetcher := func(key string) Fetcher { + return func(ctx context.Context) ([]byte, time.Duration, error) { + fetchCount.Add(1) + fetchStarted <- struct{}{} + <-fetchComplete // Wait for signal + return []byte("value-" + key), time.Minute, nil + } + } + + // Launch concurrent requests for different keys + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + key := fmt.Sprintf("key-%d", idx) + _, _ = cache.GetOrFetch(ctx, key, fetcher(key)) + }(i) + } + + // Wait for all fetches to start + for i := 0; i < 3; i++ { + <-fetchStarted + } + + // All 3 fetches should be running in parallel + assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel") + + // Release all fetches + close(fetchComplete) + wg.Wait() +} + +// TestSingleflightCache_ContextCancellation tests context cancellation +func TestSingleflightCache_ContextCancellation(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + key := "cancel-key" + fetchStarted := make(chan struct{}) + + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + close(fetchStarted) + // Simulate slow fetch + time.Sleep(500 * time.Millisecond) + return []byte("value"), time.Minute, nil + } + + // Start first request with long timeout + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + _, _ = cache.GetOrFetch(ctx, key, fetcher) + }() + + // Wait for fetch to start + <-fetchStarted + + // Start second request with short timeout + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = cache.GetOrFetch(ctx, key, fetcher) + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) + + wg.Wait() +} + +// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated +func TestSingleflightCache_ErrorPropagation(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + key := "error-prop-key" + expectedErr := errors.New("intentional error") + + var fetchCount atomic.Int32 + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + fetchCount.Add(1) + time.Sleep(50 * time.Millisecond) + return nil, 0, expectedErr + } + + // Launch multiple concurrent requests + concurrency := 5 + var wg sync.WaitGroup + errs := make([]error, concurrency) + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = cache.GetOrFetch(ctx, key, fetcher) + }(i) + } + + wg.Wait() + + // Verify all requests got the same error + for i := 0; i < concurrency; i++ { + assert.Error(t, errs[i]) + assert.Equal(t, expectedErr, errs[i]) + } + + // Verify fetcher was only called once + assert.Equal(t, int32(1), fetchCount.Load()) +} + +// TestSingleflightCache_PassthroughMethods tests that passthrough methods work +func TestSingleflightCache_PassthroughMethods(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + + t.Run("Set", func(t *testing.T) { + err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute) + require.NoError(t, err) + + val, _, exists, err := cache.Get(ctx, "set-key") + require.NoError(t, err) + assert.True(t, exists) + assert.Equal(t, []byte("set-value"), val) + }) + + t.Run("Get", func(t *testing.T) { + err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute) + require.NoError(t, err) + + val, ttl, exists, err := cache.Get(ctx, "get-key") + require.NoError(t, err) + assert.True(t, exists) + assert.Equal(t, []byte("get-value"), val) + assert.Greater(t, ttl, time.Duration(0)) + }) + + t.Run("Delete", func(t *testing.T) { + err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute) + require.NoError(t, err) + + deleted, err := cache.Delete(ctx, "delete-key") + require.NoError(t, err) + assert.True(t, deleted) + + exists, err := cache.Exists(ctx, "delete-key") + require.NoError(t, err) + assert.False(t, exists) + }) + + t.Run("Exists", func(t *testing.T) { + exists, err := cache.Exists(ctx, "nonexistent") + require.NoError(t, err) + assert.False(t, exists) + + err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute) + require.NoError(t, err) + + exists, err = cache.Exists(ctx, "exists-key") + require.NoError(t, err) + assert.True(t, exists) + }) + + t.Run("Clear", func(t *testing.T) { + err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute) + require.NoError(t, err) + + err = cache.Clear(ctx) + require.NoError(t, err) + + exists, err := cache.Exists(ctx, "clear-key") + require.NoError(t, err) + assert.False(t, exists) + }) + + t.Run("Ping", func(t *testing.T) { + err := cache.Ping(ctx) + require.NoError(t, err) + }) +} + +// TestSingleflightCache_Stats tests statistics tracking +func TestSingleflightCache_Stats(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + + // Make some calls + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + time.Sleep(50 * time.Millisecond) + return []byte("value"), time.Minute, nil + } + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = cache.GetOrFetch(ctx, "stats-key", fetcher) + }() + } + wg.Wait() + + stats := cache.GetStats() + + // Check singleflight stats exist + assert.Contains(t, stats, "singleflight_total_calls") + assert.Contains(t, stats, "singleflight_deduplicated") + assert.Contains(t, stats, "singleflight_dedup_rate") + assert.Contains(t, stats, "singleflight_inflight") + + // Verify values + assert.Equal(t, int64(5), stats["singleflight_total_calls"]) + assert.Equal(t, int64(4), stats["singleflight_deduplicated"]) + + // Also check underlying backend stats are included + assert.Contains(t, stats, "hits") + assert.Contains(t, stats, "misses") +} + +// TestSingleflightCache_ResetStats tests stats reset +func TestSingleflightCache_ResetStats(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + return []byte("value"), time.Minute, nil + } + + // Make some calls + _, _ = cache.GetOrFetch(ctx, "key1", fetcher) + _, _ = cache.GetOrFetch(ctx, "key2", fetcher) + + stats := cache.GetStats() + assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0)) + + // Reset stats + cache.ResetStats() + + stats = cache.GetStats() + assert.Equal(t, int64(0), stats["singleflight_total_calls"]) + assert.Equal(t, int64(0), stats["singleflight_deduplicated"]) +} + +// TestSingleflightCache_GetBackend tests GetBackend method +func TestSingleflightCache_GetBackend(t *testing.T) { + t.Parallel() + + backend, err := NewMemoryBackend(DefaultConfig()) + require.NoError(t, err) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + assert.Equal(t, backend, cache.GetBackend()) +} + +// BenchmarkSingleflightCache_Sequential benchmarks sequential access +func BenchmarkSingleflightCache_Sequential(b *testing.B) { + backend, _ := NewMemoryBackend(DefaultConfig()) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + return []byte("benchmark-value"), time.Minute, nil + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("key-%d", i%100) + _, _ = cache.GetOrFetch(ctx, key, fetcher) + } +} + +// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access +func BenchmarkSingleflightCache_Concurrent(b *testing.B) { + backend, _ := NewMemoryBackend(DefaultConfig()) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + time.Sleep(time.Millisecond) // Simulate slow fetch + return []byte("benchmark-value"), time.Minute, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication + _, _ = cache.GetOrFetch(ctx, key, fetcher) + i++ + } + }) +} + +// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario +func BenchmarkSingleflightCache_HighContention(b *testing.B) { + backend, _ := NewMemoryBackend(DefaultConfig()) + defer backend.Close() + + cache := NewSingleflightCache(backend) + + ctx := context.Background() + fetcher := func(ctx context.Context) ([]byte, time.Duration, error) { + time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing + return []byte("benchmark-value"), time.Minute, nil + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // All goroutines hit the same key + _, _ = cache.GetOrFetch(ctx, "hot-key", fetcher) + } + }) +} diff --git a/universal_cache.go b/universal_cache.go index 2c4b779..3cb4dc1 100644 --- a/universal_cache.go +++ b/universal_cache.go @@ -436,7 +436,7 @@ func (c *UniversalCache) Clear() { c.currentSize = 0 c.currentMemory = 0 - c.logger.Infof("UniversalCache[%s]: Cleared all items", c.config.Type) + c.logger.Debugf("UniversalCache[%s]: Cleared all items", c.config.Type) } // Size returns the number of items in the cache