mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
LRU + cache conflicts prevention. (#104)
* LRU + cache conflicts prevention. * Bugfix universalCache flooding ( issue #105 ) 1. Traefik cancels the context for old plugin instances 2. Each plugin's Close() method is called 3. The CacheInterfaceWrapper.Close() was calling cache.Close() on the shared singleton caches 4. Each Close() triggered Clear() which logged "Cleared all items" at INFO level
This commit is contained in:
Vendored
+224
-197
@@ -2,20 +2,27 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default configuration values
|
||||
const (
|
||||
defaultShardCount = 256
|
||||
defaultMaxSize = int64(10000)
|
||||
defaultMaxMemory = int64(100 * 1024 * 1024) // 100MB
|
||||
defaultCleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element *list.Element
|
||||
element interface{} // *list.Element, using interface{} to avoid import cycle
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
@@ -29,56 +36,89 @@ func (item *memoryCacheItem) isExpired() bool {
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
// MemoryCacheBackend implements the CacheBackend interface using sharded in-memory storage
|
||||
// The sharded design reduces lock contention by partitioning keys across multiple shards,
|
||||
// each with its own lock.
|
||||
type MemoryCacheBackend struct {
|
||||
shards []*cacheShard
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
cleanupDone chan bool
|
||||
cleanupDone chan struct{}
|
||||
cleanupTicker *time.Ticker
|
||||
evictionPolicy string
|
||||
lastError string
|
||||
currentMemory int64
|
||||
misses atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
sets atomic.Int64
|
||||
hits atomic.Int64
|
||||
shardCount uint32
|
||||
shardMask uint32
|
||||
maxSize int64
|
||||
currentSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Global stats (aggregated from shards)
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// State
|
||||
closed atomic.Bool
|
||||
mu sync.RWMutex // For global operations like stats and error tracking
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
// NewMemoryCacheBackend creates a new sharded memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
maxSize = defaultMaxSize
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
maxMemory = defaultMaxMemory
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
cleanupInterval = defaultCleanupInterval
|
||||
}
|
||||
|
||||
shardCount := uint32(defaultShardCount)
|
||||
|
||||
// For very small caches, reduce shard count to maintain sensible per-shard limits
|
||||
// Ensure each shard can hold at least 2 items for proper LRU behavior
|
||||
for shardCount > 1 && maxSize/int64(shardCount) < 2 {
|
||||
shardCount /= 2
|
||||
}
|
||||
if shardCount < 1 {
|
||||
shardCount = 1
|
||||
}
|
||||
|
||||
// Per-shard limits are soft hints; global limits are enforced
|
||||
// Give shards 2x the average to allow for uneven distribution
|
||||
shardMaxSize := (maxSize * 2) / int64(shardCount)
|
||||
if shardMaxSize < 4 {
|
||||
shardMaxSize = 4
|
||||
}
|
||||
shardMaxMemory := (maxMemory * 2) / int64(shardCount)
|
||||
if shardMaxMemory < 4096 {
|
||||
shardMaxMemory = 4096 // Minimum 4KB per shard
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
shards: make([]*cacheShard, shardCount),
|
||||
shardCount: shardCount,
|
||||
shardMask: shardCount - 1, // For fast modulo with power-of-2
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Initialize shards
|
||||
for i := uint32(0); i < shardCount; i++ {
|
||||
m.shards[i] = newCacheShard(shardMaxSize, shardMaxMemory)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
@@ -88,6 +128,12 @@ func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.
|
||||
return m
|
||||
}
|
||||
|
||||
// getShard returns the shard for a given key
|
||||
func (m *MemoryCacheBackend) getShard(key string) *cacheShard {
|
||||
hash := fnv32(key)
|
||||
return m.shards[hash&m.shardMask]
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
@@ -100,20 +146,19 @@ func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
// cleanupExpired removes all expired items from all shards
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
if m.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
totalRemoved := 0
|
||||
for _, shard := range m.shards {
|
||||
totalRemoved += shard.cleanup()
|
||||
}
|
||||
|
||||
if totalRemoved > 0 {
|
||||
m.evictions.Add(int64(totalRemoved))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,35 +175,23 @@ func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{},
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
shard := m.getShard(key)
|
||||
value, exists, expired := shard.get(key)
|
||||
|
||||
if expired {
|
||||
// Clean up expired item
|
||||
shard.delete(key)
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
@@ -174,113 +207,105 @@ func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interfac
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
// Calculate item size
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Enforce global limits before adding new item
|
||||
m.enforceGlobalLimits(itemSize)
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
expiresAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
shard := m.getShard(key)
|
||||
shard.set(key, value, expiresAt, itemSize)
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// enforceGlobalLimits ensures global size and memory limits are respected
|
||||
// by evicting from shards when necessary
|
||||
func (m *MemoryCacheBackend) enforceGlobalLimits(newItemSize int64) {
|
||||
// Check and enforce size limit
|
||||
for {
|
||||
totalSize, totalMemory := m.getGlobalStats()
|
||||
|
||||
needsSizeEviction := m.maxSize > 0 && totalSize >= m.maxSize
|
||||
needsMemoryEviction := m.maxMemory > 0 && totalMemory+newItemSize > m.maxMemory
|
||||
|
||||
if !needsSizeEviction && !needsMemoryEviction {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the shard with the most items and evict from it
|
||||
evicted := m.evictFromLargestShard()
|
||||
if !evicted {
|
||||
break // No more items to evict
|
||||
}
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// getGlobalStats returns the total size and memory usage across all shards
|
||||
func (m *MemoryCacheBackend) getGlobalStats() (totalSize, totalMemory int64) {
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// evictFromLargestShard evicts the globally oldest item across all shards
|
||||
// This provides true LRU behavior even with sharding
|
||||
func (m *MemoryCacheBackend) evictFromLargestShard() bool {
|
||||
var oldestShard *cacheShard
|
||||
var oldestTime time.Time
|
||||
|
||||
for _, shard := range m.shards {
|
||||
accessTime := shard.getOldestAccessTime()
|
||||
// Skip empty shards
|
||||
if accessTime.IsZero() {
|
||||
continue
|
||||
}
|
||||
// Find the shard with the oldest (earliest) access time
|
||||
if oldestShard == nil || accessTime.Before(oldestTime) {
|
||||
oldestTime = accessTime
|
||||
oldestShard = shard
|
||||
}
|
||||
}
|
||||
|
||||
if oldestShard == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return oldestShard.evictOne()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
shard := m.getShard(key)
|
||||
if shard.delete(key) {
|
||||
m.deletes.Add(1)
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
shard := m.getShard(key)
|
||||
return shard.exists(key), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
@@ -289,13 +314,9 @@ func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -306,29 +327,28 @@ func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
var allKeys []string
|
||||
for _, shard := range m.shards {
|
||||
keys := shard.keys(pattern)
|
||||
allKeys = append(allKeys, keys...)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
return allKeys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
// Size returns the total number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var total int64
|
||||
for _, shard := range m.shards {
|
||||
size, _ := shard.stats()
|
||||
total += size
|
||||
}
|
||||
|
||||
return m.currentSize, nil
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
@@ -337,24 +357,13 @@ func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
ttl, exists := shard.ttl(key)
|
||||
if !exists {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
@@ -363,20 +372,11 @@ func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Du
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
shard := m.getShard(key)
|
||||
if !shard.expire(key, ttl) {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -386,6 +386,14 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
// Aggregate stats from all shards
|
||||
var totalSize, totalMemory int64
|
||||
for _, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
totalSize += size
|
||||
totalMemory += memory
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
@@ -409,9 +417,9 @@ func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
CurrentSize: totalSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
MemoryUsage: totalMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
@@ -438,10 +446,10 @@ func (m *MemoryCacheBackend) Close() error {
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
// Clear all shards
|
||||
for _, shard := range m.shards {
|
||||
shard.clear()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -474,12 +482,28 @@ func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// GetShardCount returns the number of shards (for testing/monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardCount() uint32 {
|
||||
return m.shardCount
|
||||
}
|
||||
|
||||
// GetShardStats returns per-shard statistics (for monitoring)
|
||||
func (m *MemoryCacheBackend) GetShardStats() []map[string]int64 {
|
||||
stats := make([]map[string]int64, m.shardCount)
|
||||
for i, shard := range m.shards {
|
||||
size, memory := shard.stats()
|
||||
stats[i] = map[string]int64{
|
||||
"size": size,
|
||||
"memory": memory,
|
||||
}
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
@@ -502,7 +526,10 @@ func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
// Simplified pattern matching
|
||||
if len(pattern) > 0 && pattern[0] == '*' {
|
||||
suffix := pattern[1:]
|
||||
return len(key) >= len(suffix) && key[len(key)-len(suffix):] == suffix
|
||||
}
|
||||
return key == pattern
|
||||
}
|
||||
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheShard represents a single shard of the sharded cache
|
||||
// Each shard has its own lock for reduced contention
|
||||
type cacheShard struct {
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
mu sync.RWMutex
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
size int64
|
||||
memoryUsed int64
|
||||
}
|
||||
|
||||
// newCacheShard creates a new cache shard
|
||||
func newCacheShard(maxSize, maxMemory int64) *cacheShard {
|
||||
return &cacheShard{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
}
|
||||
}
|
||||
|
||||
// get retrieves a value from this shard
|
||||
// Returns: value, exists, expired
|
||||
func (s *cacheShard) get(key string) (interface{}, bool, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
return nil, true, true // exists but expired
|
||||
}
|
||||
|
||||
// Update access time and LRU position under write lock
|
||||
s.mu.Lock()
|
||||
// Re-check item exists (could have been deleted)
|
||||
item, exists = s.items[key]
|
||||
if exists && !item.isExpired() {
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.MoveToFront(elem)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
return item.value, true, false
|
||||
}
|
||||
|
||||
// set stores a value in this shard
|
||||
func (s *cacheShard) set(key string, value interface{}, expiresAt time.Time, size int64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if s.maxSize > 0 && s.size >= s.maxSize {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
if s.maxMemory > 0 && s.memoryUsed+size > s.maxMemory {
|
||||
s.evictLRULocked()
|
||||
}
|
||||
|
||||
// Remove old item if exists
|
||||
if oldItem, exists := s.items[key]; exists {
|
||||
s.memoryUsed -= oldItem.size
|
||||
if elem, ok := oldItem.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
s.size--
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: size,
|
||||
}
|
||||
|
||||
item.element = s.lruList.PushFront(item)
|
||||
s.items[key] = item
|
||||
s.size++
|
||||
s.memoryUsed += size
|
||||
}
|
||||
|
||||
// delete removes a key from this shard
|
||||
// Returns true if the key was deleted
|
||||
func (s *cacheShard) delete(key string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
|
||||
// exists checks if a key exists (and is not expired)
|
||||
func (s *cacheShard) exists(key string) bool {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// ttl returns the remaining TTL for a key
|
||||
func (s *cacheShard) ttl(key string) (time.Duration, bool) {
|
||||
s.mu.RLock()
|
||||
item, exists := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, true // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return remaining, true
|
||||
}
|
||||
|
||||
// expire updates the TTL for an existing key
|
||||
func (s *cacheShard) expire(key string, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
item, exists := s.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// keys returns all non-expired keys matching the pattern
|
||||
func (s *cacheShard) keys(pattern string) []string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range s.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// clear removes all items from this shard
|
||||
func (s *cacheShard) clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.items = make(map[string]*memoryCacheItem)
|
||||
s.lruList.Init()
|
||||
s.size = 0
|
||||
s.memoryUsed = 0
|
||||
}
|
||||
|
||||
// cleanup removes expired items
|
||||
// Returns the number of items removed
|
||||
func (s *cacheShard) cleanup() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var toRemove []*memoryCacheItem
|
||||
for _, item := range s.items {
|
||||
if item.isExpired() {
|
||||
toRemove = append(toRemove, item)
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range toRemove {
|
||||
s.deleteItemLocked(item)
|
||||
}
|
||||
|
||||
return len(toRemove)
|
||||
}
|
||||
|
||||
// stats returns statistics for this shard
|
||||
func (s *cacheShard) stats() (size, memory int64) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.size, s.memoryUsed
|
||||
}
|
||||
|
||||
// deleteItemLocked removes an item (must be called with lock held)
|
||||
func (s *cacheShard) deleteItemLocked(item *memoryCacheItem) {
|
||||
if elem, ok := item.element.(*list.Element); ok && elem != nil {
|
||||
s.lruList.Remove(elem)
|
||||
}
|
||||
delete(s.items, item.key)
|
||||
s.size--
|
||||
s.memoryUsed -= item.size
|
||||
}
|
||||
|
||||
// evictLRULocked evicts the least recently used item (must be called with lock held)
|
||||
func (s *cacheShard) evictLRULocked() bool {
|
||||
if s.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
s.deleteItemLocked(item)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// evictOne evicts one item from this shard (for global limit enforcement)
|
||||
func (s *cacheShard) evictOne() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.evictLRULocked()
|
||||
}
|
||||
|
||||
// getOldestAccessTime returns the access time of the LRU item (oldest) in this shard
|
||||
// Returns zero time if shard is empty
|
||||
func (s *cacheShard) getOldestAccessTime() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.lruList.Len() == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
element := s.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
return item.accessedAt
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a hash of a string
|
||||
// This is a fast, well-distributed hash function
|
||||
func fnv32(key string) uint32 {
|
||||
const (
|
||||
offset32 = uint32(2166136261)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
hash := offset32
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash ^= uint32(key[i])
|
||||
hash *= prime32
|
||||
}
|
||||
return hash
|
||||
}
|
||||
+283
@@ -0,0 +1,283 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShardedCache_ShardDistribution tests that keys are distributed across shards
|
||||
func TestShardedCache_ShardDistribution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a cache with large enough size to have multiple shards
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024 // 100MB
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add many items to see distribution
|
||||
numItems := 1000
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("dist-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("dist-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Check that items are distributed across multiple shards
|
||||
shardStats := backend.MemoryCacheBackend.GetShardStats()
|
||||
nonEmptyShards := 0
|
||||
for _, stat := range shardStats {
|
||||
if stat["size"] > 0 {
|
||||
nonEmptyShards++
|
||||
}
|
||||
}
|
||||
|
||||
// With good hash distribution, we should have items in multiple shards
|
||||
assert.Greater(t, nonEmptyShards, 1, "Items should be distributed across multiple shards")
|
||||
}
|
||||
|
||||
// TestShardedCache_ShardCount tests that shard count adapts to cache size
|
||||
func TestShardedCache_ShardCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
maxSize int
|
||||
expectLowShards bool
|
||||
}{
|
||||
{5, true}, // Very small cache should have fewer shards
|
||||
{100, true}, // Small cache should have fewer shards
|
||||
{10000, false}, // Large cache should have default shards
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("MaxSize_%d", tt.maxSize), func(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = tt.maxSize
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
shardCount := backend.MemoryCacheBackend.GetShardCount()
|
||||
|
||||
if tt.expectLowShards {
|
||||
assert.Less(t, shardCount, uint32(256), "Small cache should have fewer shards")
|
||||
} else {
|
||||
assert.Equal(t, uint32(256), shardCount, "Large cache should have default shard count")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShardedCache_ConcurrentSameKey tests concurrent access to the same key
|
||||
func TestShardedCache_ConcurrentSameKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
key := "concurrent-same-key"
|
||||
initialValue := []byte("initial-value")
|
||||
|
||||
err = backend.Set(ctx, key, initialValue, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of reads and writes
|
||||
if j%3 == 0 {
|
||||
newValue := []byte(fmt.Sprintf("value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, newValue, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Key should still exist
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// TestShardedCache_GlobalLRUEviction tests that global LRU is maintained
|
||||
func TestShardedCache_GlobalLRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a small cache to force eviction
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
// Small delay to ensure different access times
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
// Access some items to make them recently used
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
_, _, _, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Add more items to trigger eviction
|
||||
for i := 10; i < 15; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Recently accessed items (5-9) should still exist
|
||||
for i := 5; i < 10; i++ {
|
||||
key := fmt.Sprintf("global-lru-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item %d should exist", i)
|
||||
}
|
||||
|
||||
// Check eviction stats
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestShardedCache_StatsAggregation tests that stats are aggregated correctly
|
||||
func TestShardedCache_StatsAggregation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 10000
|
||||
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items to multiple shards
|
||||
numItems := 100
|
||||
for i := 0; i < numItems; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("stats-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Read some items
|
||||
for i := 0; i < numItems/2; i++ {
|
||||
key := fmt.Sprintf("stats-key-%d", i)
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Read non-existent items
|
||||
for i := 0; i < 10; i++ {
|
||||
backend.Get(ctx, fmt.Sprintf("nonexistent-%d", i))
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
|
||||
// Verify stats
|
||||
assert.Equal(t, int64(numItems), stats["sets"].(int64), "Sets should match")
|
||||
assert.Equal(t, int64(numItems/2), stats["hits"].(int64), "Hits should match")
|
||||
assert.Equal(t, int64(10), stats["misses"].(int64), "Misses should match")
|
||||
assert.Equal(t, int64(numItems), stats["size"].(int64), "Size should match")
|
||||
|
||||
// Verify hit rate
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
expectedHitRate := float64(numItems/2) / float64(numItems/2+10)
|
||||
assert.InDelta(t, expectedHitRate, hitRate, 0.01, "Hit rate should match")
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_Parallel benchmarks parallel access
|
||||
func BenchmarkShardedCache_Parallel(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10000; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("bench-key-%d", i%10000)
|
||||
backend.Get(ctx, key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedCache_MixedOps benchmarks mixed operations
|
||||
func BenchmarkShardedCache_MixedOps(b *testing.B) {
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100000
|
||||
config.MaxMemoryBytes = 100 * 1024 * 1024
|
||||
|
||||
backend, _ := NewMemoryBackend(config)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("mixed-key-%d", i%1000)
|
||||
if i%3 == 0 {
|
||||
value := []byte(fmt.Sprintf("mixed-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
} else {
|
||||
backend.Get(ctx, key)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
+20
-30
@@ -45,21 +45,11 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
// Get TTL using the TTL method
|
||||
ttl, ttlErr := m.MemoryCacheBackend.TTL(ctx, key)
|
||||
if ttlErr != nil {
|
||||
// If we can't get TTL, still return the value with 0 TTL
|
||||
ttl = 0
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
@@ -68,8 +58,7 @@ func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
// If it's not already []byte, return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
@@ -123,19 +112,20 @@ func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
"shard_count": m.MemoryCacheBackend.GetShardCount(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+107
-11
@@ -431,39 +431,135 @@ func isRetryableError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
// SetMany stores multiple values in Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For single items, use regular Set
|
||||
if len(items) == 1 {
|
||||
for key, value := range items {
|
||||
return r.Set(ctx, key, value, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all SET commands
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
if ttl > 0 {
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs
|
||||
pipeline.Queue("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs
|
||||
pipeline.Queue("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
pipeline.Queue("SET", prefixedKey, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pipeline SetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Check responses for errors (each should be "OK")
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
continue
|
||||
}
|
||||
if str, ok := resp.(string); ok && str == "OK" {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("SetMany: unexpected response at index %d: %v", i, resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
// GetMany retrieves multiple values from Redis using pipelining for efficiency
|
||||
// This reduces N round-trips to a single round-trip
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
// For single key, use regular Get
|
||||
if len(keys) == 1 {
|
||||
result := make(map[string][]byte)
|
||||
value, _, exists, err := r.Get(ctx, keys[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
result[keys[0]] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
// Queue all GET commands
|
||||
prefixedKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
prefixedKeys[i] = r.prefixKey(key)
|
||||
pipeline.Queue("GET", prefixedKeys[i])
|
||||
}
|
||||
|
||||
// Execute pipeline
|
||||
responses, err := pipeline.Execute()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pipeline GetMany failed: %w", err)
|
||||
}
|
||||
|
||||
// Process responses
|
||||
result := make(map[string][]byte)
|
||||
for i, resp := range responses {
|
||||
if resp == nil {
|
||||
// Key doesn't exist
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
// Invalid response, skip this key
|
||||
r.misses.Add(1)
|
||||
continue
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
result[keys[i]] = []byte(value)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
|
||||
+461
@@ -0,0 +1,461 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// setupTestRedis creates a miniredis instance for testing
|
||||
func setupTestRedis(t *testing.T) (*miniredis.Miniredis, *RedisBackend) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "test:",
|
||||
PoolSize: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
backend.Close()
|
||||
})
|
||||
|
||||
return mr, backend
|
||||
}
|
||||
|
||||
// TestPipeline_Basic tests basic pipeline functionality
|
||||
func TestPipeline_Basic(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.Addr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SingleCommand", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "single-key", "single-value")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
})
|
||||
|
||||
t.Run("MultipleCommands", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("SET", "key1", "value1")
|
||||
pipeline.Queue("SET", "key2", "value2")
|
||||
pipeline.Queue("SET", "key3", "value3")
|
||||
pipeline.Queue("GET", "key1")
|
||||
pipeline.Queue("GET", "key2")
|
||||
pipeline.Queue("GET", "key3")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 6)
|
||||
|
||||
// First 3 are SET responses
|
||||
assert.Equal(t, "OK", responses[0])
|
||||
assert.Equal(t, "OK", responses[1])
|
||||
assert.Equal(t, "OK", responses[2])
|
||||
|
||||
// Last 3 are GET responses
|
||||
assert.Equal(t, "value1", responses[3])
|
||||
assert.Equal(t, "value2", responses[4])
|
||||
assert.Equal(t, "value3", responses[5])
|
||||
})
|
||||
|
||||
t.Run("EmptyPipeline", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, responses)
|
||||
})
|
||||
|
||||
t.Run("NilResponses", func(t *testing.T) {
|
||||
pipeline := conn.NewPipeline()
|
||||
pipeline.Queue("GET", "nonexistent-key")
|
||||
|
||||
responses, err := pipeline.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, responses, 1)
|
||||
assert.Nil(t, responses[0])
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_SetMany tests pipelined SetMany
|
||||
func TestPipeline_SetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetManyItems", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
items[fmt.Sprintf("setmany-key-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetManyEmpty", func(t *testing.T) {
|
||||
err := backend.SetMany(ctx, map[string][]byte{}, time.Minute)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetManySingleItem", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"single-setmany": []byte("single-value"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
value, _, exists, err := backend.Get(ctx, "single-setmany")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("single-value"), value)
|
||||
})
|
||||
|
||||
t.Run("SetManyNoTTL", func(t *testing.T) {
|
||||
items := map[string][]byte{
|
||||
"nottl-key1": []byte("value1"),
|
||||
"nottl-key2": []byte("value2"),
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Keys should exist
|
||||
for key := range items {
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_GetMany tests pipelined GetMany
|
||||
func TestPipeline_GetMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("getmany-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("value-%d", i))
|
||||
err := backend.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("GetManyExisting", func(t *testing.T) {
|
||||
keys := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
keys[i] = fmt.Sprintf("getmany-key-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 10)
|
||||
|
||||
for i, key := range keys {
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), results[key])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyMixed", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"getmany-key-0", // exists
|
||||
"nonexistent-key-1", // doesn't exist
|
||||
"getmany-key-2", // exists
|
||||
"nonexistent-key-2", // doesn't exist
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
|
||||
assert.Equal(t, []byte("value-0"), results["getmany-key-0"])
|
||||
assert.Equal(t, []byte("value-2"), results["getmany-key-2"])
|
||||
assert.NotContains(t, results, "nonexistent-key-1")
|
||||
assert.NotContains(t, results, "nonexistent-key-2")
|
||||
})
|
||||
|
||||
t.Run("GetManyEmpty", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, results)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
|
||||
t.Run("GetManySingleKey", func(t *testing.T) {
|
||||
results, err := backend.GetMany(ctx, []string{"getmany-key-5"})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, []byte("value-5"), results["getmany-key-5"])
|
||||
})
|
||||
|
||||
t.Run("GetManyAllNonexistent", func(t *testing.T) {
|
||||
keys := []string{
|
||||
"nonexistent-1",
|
||||
"nonexistent-2",
|
||||
"nonexistent-3",
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_LargeBatch tests pipelining with large batches
|
||||
func TestPipeline_LargeBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany100Items", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("large-batch-%d", i)] = []byte(fmt.Sprintf("value-%d", i))
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify random samples
|
||||
for _, i := range []int{0, 25, 50, 75, 99} {
|
||||
key := fmt.Sprintf("large-batch-%d", i)
|
||||
value, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("value-%d", i)), value)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany100Items", func(t *testing.T) {
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("large-batch-%d", i)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 100)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPipeline_Stats tests that stats are tracked correctly with pipelining
|
||||
func TestPipeline_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, backend := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some items
|
||||
items := map[string][]byte{
|
||||
"stats-key-1": []byte("value1"),
|
||||
"stats-key-2": []byte("value2"),
|
||||
}
|
||||
err := backend.SetMany(ctx, items, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get items (some exist, some don't)
|
||||
keys := []string{
|
||||
"stats-key-1",
|
||||
"stats-key-2",
|
||||
"stats-key-nonexistent",
|
||||
}
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
// Check stats
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
|
||||
assert.Equal(t, int64(2), hits, "Should have 2 hits")
|
||||
assert.Equal(t, int64(1), misses, "Should have 1 miss")
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_SetMany benchmarks SetMany with pipelining
|
||||
func BenchmarkPipeline_SetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 100; i++ {
|
||||
items[fmt.Sprintf("bench-key-%d", i)] = []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_GetMany benchmarks GetMany with pipelining
|
||||
func BenchmarkPipeline_GetMany(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pre-populate cache
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bench-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("bench-value-%d", i))
|
||||
backend.Set(ctx, key, value, time.Hour)
|
||||
}
|
||||
|
||||
// Prepare keys
|
||||
keys := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
keys[i] = fmt.Sprintf("bench-key-%d", i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPipeline_VsSequential benchmarks pipeline vs sequential operations
|
||||
func BenchmarkPipeline_VsSequential(b *testing.B) {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer mr.Close()
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: mr.Addr(),
|
||||
RedisPrefix: "bench:",
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare items
|
||||
items := make(map[string][]byte)
|
||||
keys := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("compare-key-%d", i)
|
||||
keys[i] = key
|
||||
items[key] = []byte(fmt.Sprintf("compare-value-%d", i))
|
||||
}
|
||||
|
||||
b.Run("Pipelined-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = backend.SetMany(ctx, items, time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Set", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for key, value := range items {
|
||||
_ = backend.Set(ctx, key, value, time.Minute)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for get benchmarks
|
||||
_ = backend.SetMany(ctx, items, time.Hour)
|
||||
|
||||
b.Run("Pipelined-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = backend.GetMany(ctx, keys)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Sequential-Get", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, key := range keys {
|
||||
_, _, _, _ = backend.Get(ctx, key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
+117
@@ -336,3 +336,120 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Pipeline represents a Redis pipeline for batch operations
|
||||
// It queues multiple commands and executes them in a single round-trip
|
||||
type Pipeline struct {
|
||||
conn *RedisConn
|
||||
commands []pipelineCommand
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// pipelineCommand represents a single command in the pipeline
|
||||
type pipelineCommand struct {
|
||||
command string
|
||||
args []string
|
||||
}
|
||||
|
||||
// NewPipeline creates a new pipeline for the connection
|
||||
func (c *RedisConn) NewPipeline() *Pipeline {
|
||||
return &Pipeline{
|
||||
conn: c,
|
||||
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
||||
}
|
||||
}
|
||||
|
||||
// Queue adds a command to the pipeline
|
||||
func (p *Pipeline) Queue(command string, args ...string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.commands = append(p.commands, pipelineCommand{
|
||||
command: command,
|
||||
args: args,
|
||||
})
|
||||
}
|
||||
|
||||
// Execute sends all queued commands and returns all responses
|
||||
// Returns a slice of responses in the same order as commands were queued
|
||||
func (p *Pipeline) Execute() ([]interface{}, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if len(p.commands) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.conn.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.conn.mu.Lock()
|
||||
defer p.conn.mu.Unlock()
|
||||
|
||||
// Set write timeout for all commands
|
||||
if p.conn.writeTimeout > 0 {
|
||||
// Use longer timeout for batch operations
|
||||
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second // Cap at 30 seconds
|
||||
}
|
||||
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Write all commands (pipelining - send all before reading any responses)
|
||||
writer := NewRESPWriter(p.conn.conn)
|
||||
for _, cmd := range p.commands {
|
||||
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
||||
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
||||
writer.Release()
|
||||
p.conn.closed.Store(true)
|
||||
return nil, fmt.Errorf("pipeline write error: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Release()
|
||||
|
||||
// Set read timeout for all responses
|
||||
if p.conn.readTimeout > 0 {
|
||||
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
||||
if timeout > 30*time.Second {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
// Read all responses
|
||||
responses := make([]interface{}, len(p.commands))
|
||||
reader := NewRESPReader(p.conn.conn)
|
||||
defer reader.Release()
|
||||
|
||||
for i := range p.commands {
|
||||
resp, err := reader.ReadResponse()
|
||||
if err != nil {
|
||||
// For nil responses, store nil instead of erroring
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
responses[i] = nil
|
||||
continue
|
||||
}
|
||||
p.conn.closed.Store(true)
|
||||
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
||||
}
|
||||
responses[i] = resp
|
||||
}
|
||||
|
||||
return responses, nil
|
||||
}
|
||||
|
||||
// Clear resets the pipeline for reuse
|
||||
func (p *Pipeline) Clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.commands = p.commands[:0]
|
||||
}
|
||||
|
||||
// Len returns the number of queued commands
|
||||
func (p *Pipeline) Len() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.commands)
|
||||
}
|
||||
|
||||
+183
@@ -0,0 +1,183 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleflightCache wraps a CacheBackend with singleflight deduplication
|
||||
// to prevent thundering herd problems when multiple concurrent requests
|
||||
// try to fetch the same uncached key.
|
||||
type SingleflightCache struct {
|
||||
backend CacheBackend
|
||||
mu sync.Mutex
|
||||
calls map[string]*singleflightCall
|
||||
|
||||
// Metrics
|
||||
deduplicatedCalls atomic.Int64
|
||||
totalCalls atomic.Int64
|
||||
}
|
||||
|
||||
// singleflightCall represents an in-flight or completed fetch call
|
||||
type singleflightCall struct {
|
||||
wg sync.WaitGroup
|
||||
val []byte
|
||||
ttl time.Duration
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
// NewSingleflightCache creates a new singleflight-wrapped cache backend
|
||||
func NewSingleflightCache(backend CacheBackend) *SingleflightCache {
|
||||
return &SingleflightCache{
|
||||
backend: backend,
|
||||
calls: make(map[string]*singleflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
// Fetcher is a function type that fetches data when cache misses
|
||||
type Fetcher func(ctx context.Context) (value []byte, ttl time.Duration, err error)
|
||||
|
||||
// GetOrFetch retrieves a value from cache or calls the fetcher exactly once
|
||||
// per key when there's a cache miss. Concurrent calls for the same key will
|
||||
// wait for the first call to complete and share its result.
|
||||
func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher Fetcher) ([]byte, error) {
|
||||
s.totalCalls.Add(1)
|
||||
|
||||
// Try cache first
|
||||
value, _, exists, err := s.backend.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Cache miss - use singleflight
|
||||
s.mu.Lock()
|
||||
|
||||
// Check if there's already an in-flight call for this key
|
||||
if call, ok := s.calls[key]; ok {
|
||||
s.mu.Unlock()
|
||||
s.deduplicatedCalls.Add(1)
|
||||
|
||||
// Wait for the in-flight call to complete
|
||||
call.wg.Wait()
|
||||
|
||||
// Check context cancellation
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Create new call
|
||||
call := &singleflightCall{}
|
||||
call.wg.Add(1)
|
||||
s.calls[key] = call
|
||||
s.mu.Unlock()
|
||||
|
||||
// Execute the fetcher
|
||||
call.val, call.ttl, call.err = fetcher(ctx)
|
||||
call.done = true
|
||||
|
||||
// If successful, store in cache
|
||||
if call.err == nil && call.val != nil {
|
||||
// Use a background context for cache storage to ensure it completes
|
||||
// even if the original context is cancelled
|
||||
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Signal waiting goroutines
|
||||
call.wg.Done()
|
||||
|
||||
// Clean up the call from the map after a short delay
|
||||
// This allows late arrivals to still benefit from the result
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s.mu.Lock()
|
||||
if c, ok := s.calls[key]; ok && c == call {
|
||||
delete(s.calls, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
return call.val, call.err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the underlying cache backend
|
||||
func (s *SingleflightCache) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
return s.backend.Get(ctx, key)
|
||||
}
|
||||
|
||||
// Set stores a value in the underlying cache backend
|
||||
func (s *SingleflightCache) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return s.backend.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the underlying cache backend
|
||||
func (s *SingleflightCache) Delete(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the underlying cache backend
|
||||
func (s *SingleflightCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return s.backend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the underlying cache backend
|
||||
func (s *SingleflightCache) Clear(ctx context.Context) error {
|
||||
return s.backend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics including singleflight metrics
|
||||
func (s *SingleflightCache) GetStats() map[string]interface{} {
|
||||
stats := s.backend.GetStats()
|
||||
|
||||
// Add singleflight-specific stats
|
||||
totalCalls := s.totalCalls.Load()
|
||||
deduped := s.deduplicatedCalls.Load()
|
||||
|
||||
stats["singleflight_total_calls"] = totalCalls
|
||||
stats["singleflight_deduplicated"] = deduped
|
||||
if totalCalls > 0 {
|
||||
stats["singleflight_dedup_rate"] = float64(deduped) / float64(totalCalls)
|
||||
} else {
|
||||
stats["singleflight_dedup_rate"] = float64(0)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
stats["singleflight_inflight"] = len(s.calls)
|
||||
s.mu.Unlock()
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend
|
||||
func (s *SingleflightCache) Close() error {
|
||||
return s.backend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (s *SingleflightCache) Ping(ctx context.Context) error {
|
||||
return s.backend.Ping(ctx)
|
||||
}
|
||||
|
||||
// GetBackend returns the underlying cache backend
|
||||
func (s *SingleflightCache) GetBackend() CacheBackend {
|
||||
return s.backend
|
||||
}
|
||||
|
||||
// ResetStats resets the singleflight statistics
|
||||
func (s *SingleflightCache) ResetStats() {
|
||||
s.totalCalls.Store(0)
|
||||
s.deduplicatedCalls.Store(0)
|
||||
}
|
||||
|
||||
// Ensure SingleflightCache implements CacheBackend
|
||||
var _ CacheBackend = (*SingleflightCache)(nil)
|
||||
+510
@@ -0,0 +1,510 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSingleflightCache_BasicGetOrFetch tests basic GetOrFetch functionality
|
||||
func TestSingleflightCache_BasicGetOrFetch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("CacheHit", func(t *testing.T) {
|
||||
key := "existing-key"
|
||||
value := []byte("existing-value")
|
||||
|
||||
// Pre-populate cache
|
||||
err := cache.Set(ctx, key, value, time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return []byte("fetched-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, value, result)
|
||||
assert.False(t, fetchCalled, "Fetcher should not be called on cache hit")
|
||||
})
|
||||
|
||||
t.Run("CacheMiss", func(t *testing.T) {
|
||||
key := "missing-key"
|
||||
expectedValue := []byte("fetched-value")
|
||||
|
||||
var fetchCalled bool
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCalled = true
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, result)
|
||||
assert.True(t, fetchCalled, "Fetcher should be called on cache miss")
|
||||
|
||||
// Verify value was stored in cache
|
||||
cached, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, cached)
|
||||
})
|
||||
|
||||
t.Run("FetcherError", func(t *testing.T) {
|
||||
key := "error-key"
|
||||
expectedErr := errors.New("fetch failed")
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
result, err := cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// Verify nothing was stored in cache
|
||||
_, _, exists, err := cache.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Deduplication tests that concurrent calls are deduplicated
|
||||
func TestSingleflightCache_Deduplication(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "dedup-key"
|
||||
expectedValue := []byte("dedup-value")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return expectedValue, time.Minute, nil
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([][]byte, concurrency)
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx], errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same result
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.NoError(t, errs[i])
|
||||
assert.Equal(t, expectedValue, results[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load(), "Fetcher should only be called once")
|
||||
|
||||
// Verify deduplication stats
|
||||
stats := cache.GetStats()
|
||||
deduped := stats["singleflight_deduplicated"].(int64)
|
||||
assert.Equal(t, int64(concurrency-1), deduped, "Should have deduplicated N-1 calls")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_DifferentKeys tests that different keys can fetch in parallel
|
||||
func TestSingleflightCache_DifferentKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetchStarted := make(chan struct{}, 3)
|
||||
fetchComplete := make(chan struct{})
|
||||
|
||||
fetcher := func(key string) Fetcher {
|
||||
return func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
fetchStarted <- struct{}{}
|
||||
<-fetchComplete // Wait for signal
|
||||
return []byte("value-" + key), time.Minute, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Launch concurrent requests for different keys
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("key-%d", idx)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher(key))
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all fetches to start
|
||||
for i := 0; i < 3; i++ {
|
||||
<-fetchStarted
|
||||
}
|
||||
|
||||
// All 3 fetches should be running in parallel
|
||||
assert.Equal(t, int32(3), fetchCount.Load(), "All three fetches should run in parallel")
|
||||
|
||||
// Release all fetches
|
||||
close(fetchComplete)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ContextCancellation tests context cancellation
|
||||
func TestSingleflightCache_ContextCancellation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
key := "cancel-key"
|
||||
fetchStarted := make(chan struct{})
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
close(fetchStarted)
|
||||
// Simulate slow fetch
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Start first request with long timeout
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx := context.Background()
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}()
|
||||
|
||||
// Wait for fetch to start
|
||||
<-fetchStarted
|
||||
|
||||
// Start second request with short timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = cache.GetOrFetch(ctx, key, fetcher)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ErrorPropagation tests that errors are properly propagated
|
||||
func TestSingleflightCache_ErrorPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
key := "error-prop-key"
|
||||
expectedErr := errors.New("intentional error")
|
||||
|
||||
var fetchCount atomic.Int32
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
fetchCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return nil, 0, expectedErr
|
||||
}
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
concurrency := 5
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all requests got the same error
|
||||
for i := 0; i < concurrency; i++ {
|
||||
assert.Error(t, errs[i])
|
||||
assert.Equal(t, expectedErr, errs[i])
|
||||
}
|
||||
|
||||
// Verify fetcher was only called once
|
||||
assert.Equal(t, int32(1), fetchCount.Load())
|
||||
}
|
||||
|
||||
// TestSingleflightCache_PassthroughMethods tests that passthrough methods work
|
||||
func TestSingleflightCache_PassthroughMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Set", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "set-key", []byte("set-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, _, exists, err := cache.Get(ctx, "set-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("set-value"), val)
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "get-key", []byte("get-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
val, ttl, exists, err := cache.Get(ctx, "get-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("get-value"), val)
|
||||
assert.Greater(t, ttl, time.Duration(0))
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "delete-key", []byte("delete-value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := cache.Delete(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := cache.Exists(ctx, "delete-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
exists, err := cache.Exists(ctx, "nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = cache.Set(ctx, "exists-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = cache.Exists(ctx, "exists-key")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
err := cache.Set(ctx, "clear-key", []byte("value"), time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err := cache.Exists(ctx, "clear-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
err := cache.Ping(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSingleflightCache_Stats tests statistics tracking
|
||||
func TestSingleflightCache_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Make some calls
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = cache.GetOrFetch(ctx, "stats-key", fetcher)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
stats := cache.GetStats()
|
||||
|
||||
// Check singleflight stats exist
|
||||
assert.Contains(t, stats, "singleflight_total_calls")
|
||||
assert.Contains(t, stats, "singleflight_deduplicated")
|
||||
assert.Contains(t, stats, "singleflight_dedup_rate")
|
||||
assert.Contains(t, stats, "singleflight_inflight")
|
||||
|
||||
// Verify values
|
||||
assert.Equal(t, int64(5), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(4), stats["singleflight_deduplicated"])
|
||||
|
||||
// Also check underlying backend stats are included
|
||||
assert.Contains(t, stats, "hits")
|
||||
assert.Contains(t, stats, "misses")
|
||||
}
|
||||
|
||||
// TestSingleflightCache_ResetStats tests stats reset
|
||||
func TestSingleflightCache_ResetStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("value"), time.Minute, nil
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
_, _ = cache.GetOrFetch(ctx, "key1", fetcher)
|
||||
_, _ = cache.GetOrFetch(ctx, "key2", fetcher)
|
||||
|
||||
stats := cache.GetStats()
|
||||
assert.Greater(t, stats["singleflight_total_calls"].(int64), int64(0))
|
||||
|
||||
// Reset stats
|
||||
cache.ResetStats()
|
||||
|
||||
stats = cache.GetStats()
|
||||
assert.Equal(t, int64(0), stats["singleflight_total_calls"])
|
||||
assert.Equal(t, int64(0), stats["singleflight_deduplicated"])
|
||||
}
|
||||
|
||||
// TestSingleflightCache_GetBackend tests GetBackend method
|
||||
func TestSingleflightCache_GetBackend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
assert.Equal(t, backend, cache.GetBackend())
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Sequential benchmarks sequential access
|
||||
func BenchmarkSingleflightCache_Sequential(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := fmt.Sprintf("key-%d", i%100)
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_Concurrent benchmarks concurrent access
|
||||
func BenchmarkSingleflightCache_Concurrent(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(time.Millisecond) // Simulate slow fetch
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i%10) // Only 10 unique keys to force deduplication
|
||||
_, _ = cache.GetOrFetch(ctx, key, fetcher)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkSingleflightCache_HighContention benchmarks high contention scenario
|
||||
func BenchmarkSingleflightCache_HighContention(b *testing.B) {
|
||||
backend, _ := NewMemoryBackend(DefaultConfig())
|
||||
defer backend.Close()
|
||||
|
||||
cache := NewSingleflightCache(backend)
|
||||
|
||||
ctx := context.Background()
|
||||
fetcher := func(ctx context.Context) ([]byte, time.Duration, error) {
|
||||
time.Sleep(10 * time.Millisecond) // Slow fetch to force queuing
|
||||
return []byte("benchmark-value"), time.Minute, nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
// All goroutines hit the same key
|
||||
_, _ = cache.GetOrFetch(ctx, "hot-key", fetcher)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user