Add redis support for distributed caching (#83)

* Add redis support for distributed caching

* Move towards the self-provided Redis connection pool and RESP protocol implementation.
Official redis client library won't work with yaegi.

* fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi.

* ... and another all nighter.

* fixup! ... and another all nighter.

* fixup! fixup! ... and another all nighter.

* fixup! fixup! fixup! ... and another all nighter.

* Resolve issue #85 by adding ability to set custom claims in JWT tokens

* Remove redundant validation in auth middleware ( issue #89 )

* Add ability to set cookie prefix for session cookies ( #87 )

* fixup! Add ability to set cookie prefix for session cookies ( #87 )

* Add ability to set cookie max age - issue #91

* Potential fix for code scanning alert no. 10: Size computation for allocation may overflow

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fixup! Merge main into 0.8.0-redis: resolve conflicts

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
2025-11-30 02:18:46 +00:00
committed by GitHub
parent 5fcbd54955
commit e64fc7f730
318 changed files with 100989 additions and 948 deletions
+90
View File
@@ -0,0 +1,90 @@
package backends
import "time"
// BackendType represents the type of cache backend
type BackendType string
const (
BackendTypeMemory BackendType = "memory"
BackendTypeRedis BackendType = "redis"
BackendTypeHybrid BackendType = "hybrid"
// Aliases for backward compatibility
TypeMemory BackendType = "memory"
TypeRedis BackendType = "redis"
TypeHybrid BackendType = "hybrid"
)
// Config provides common configuration for cache backends
type Config struct {
// Type specifies the backend type
Type BackendType
// Memory backend settings
MaxSize int
MaxMemoryBytes int64
CleanupInterval time.Duration
// Redis backend settings
RedisAddr string
RedisPassword string
RedisDB int
RedisPrefix string
PoolSize int
// Hybrid backend settings
L1Config *Config // Memory cache (L1)
L2Config *Config // Redis cache (L2)
AsyncWrites bool // Write to L2 asynchronously
// Resilience settings
EnableCircuitBreaker bool
EnableHealthCheck bool
HealthCheckInterval time.Duration
// Metrics
EnableMetrics bool
}
// DefaultConfig returns a default configuration for in-memory caching
func DefaultConfig() *Config {
return &Config{
Type: BackendTypeMemory,
MaxSize: 1000,
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
CleanupInterval: 5 * time.Minute,
EnableMetrics: true,
}
}
// DefaultRedisConfig returns a default configuration for Redis caching
func DefaultRedisConfig(addr string) *Config {
return &Config{
Type: BackendTypeRedis,
RedisAddr: addr,
RedisDB: 0,
RedisPrefix: "traefikoidc:",
PoolSize: 10,
EnableCircuitBreaker: true,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
EnableMetrics: true,
}
}
// DefaultHybridConfig returns a default configuration for hybrid caching
func DefaultHybridConfig(redisAddr string) *Config {
return &Config{
Type: BackendTypeHybrid,
L1Config: &Config{
Type: BackendTypeMemory,
MaxSize: 500,
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
CleanupInterval: 1 * time.Minute,
},
L2Config: DefaultRedisConfig(redisAddr),
AsyncWrites: true,
EnableMetrics: true,
}
}
+59
View File
@@ -0,0 +1,59 @@
//go:build !yaegi
package backends
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDefaultHybridConfig verifies the default hybrid configuration
func TestDefaultHybridConfig(t *testing.T) {
redisAddr := "localhost:6379"
config := DefaultHybridConfig(redisAddr)
require.NotNil(t, config)
// Verify top-level config
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.True(t, config.AsyncWrites)
assert.True(t, config.EnableMetrics)
// Verify L1 (memory) config
require.NotNil(t, config.L1Config)
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
assert.Equal(t, 500, config.L1Config.MaxSize)
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
// Verify L2 (Redis) config exists
require.NotNil(t, config.L2Config)
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
}
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
tests := []struct {
name string
redisAddr string
}{
{"localhost", "localhost:6379"},
{"remote host", "redis.example.com:6379"},
{"IP address", "192.168.1.100:6379"},
{"custom port", "localhost:6380"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := DefaultHybridConfig(tt.redisAddr)
require.NotNil(t, config)
assert.Equal(t, BackendTypeHybrid, config.Type)
assert.NotNil(t, config.L1Config)
assert.NotNil(t, config.L2Config)
})
}
}
+38
View File
@@ -0,0 +1,38 @@
package backends
import "errors"
var (
// ErrBackendClosed is returned when operating on a closed backend
ErrBackendClosed = errors.New("cache backend is closed")
// ErrKeyNotFound is returned when a key doesn't exist
ErrKeyNotFound = errors.New("key not found")
// ErrCacheMiss indicates the requested key was not found in the cache
ErrCacheMiss = errors.New("cache miss")
// ErrBackendUnavailable indicates the cache backend is not available
ErrBackendUnavailable = errors.New("cache backend unavailable")
// ErrInvalidValue indicates the cached value is invalid or corrupted
ErrInvalidValue = errors.New("invalid cached value")
// ErrInvalidTTL is returned when TTL is invalid
ErrInvalidTTL = errors.New("invalid TTL")
// ErrConnectionFailed is returned when connection fails
ErrConnectionFailed = errors.New("connection failed")
// ErrCircuitOpen is returned when circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTimeout is returned when operation times out
ErrTimeout = errors.New("operation timeout")
// ErrSerializationFailed is returned when serialization fails
ErrSerializationFailed = errors.New("serialization failed")
// ErrDeserializationFailed is returned when deserialization fails
ErrDeserializationFailed = errors.New("deserialization failed")
)
+695
View File
@@ -0,0 +1,695 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
// It provides automatic failover, async writes for non-critical data, and optimized read paths
type HybridBackend struct {
primary CacheBackend // L1: Memory cache for fast access
secondary CacheBackend // L2: Redis cache for distributed access
// Configuration
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
asyncWriteBuffer chan *asyncWriteItem
// Metrics
l1Hits atomic.Int64
l2Hits atomic.Int64
misses atomic.Int64
l1Writes atomic.Int64
l2Writes atomic.Int64
errors atomic.Int64
// Fallback tracking
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
lastL2Error atomic.Value // Stores last L2 error timestamp
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Logging
logger Logger
}
// asyncWriteItem represents an async write operation
type asyncWriteItem struct {
key string
value []byte
ttl time.Duration
ctx context.Context
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// defaultLogger provides a basic logger implementation
type defaultLogger struct {
*log.Logger
}
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
l.Printf("[DEBUG] "+format, args...)
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
l.Printf("[INFO] "+format, args...)
}
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
l.Printf("[WARN] "+format, args...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
l.Printf("[ERROR] "+format, args...)
}
// HybridConfig provides configuration for the hybrid backend
type HybridConfig struct {
Primary CacheBackend
Secondary CacheBackend
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
AsyncBufferSize int
Logger Logger
}
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.Primary == nil {
return nil, fmt.Errorf("primary (L1) backend is required")
}
if config.Secondary == nil {
return nil, fmt.Errorf("secondary (L2) backend is required")
}
if config.Logger == nil {
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
}
if config.AsyncBufferSize <= 0 {
config.AsyncBufferSize = 1000
}
// Default critical cache types that require synchronous writes
if config.SyncWriteCacheTypes == nil {
config.SyncWriteCacheTypes = map[string]bool{
"blacklist": true, // Token blacklist must be immediately consistent
"token": true, // Token validation is critical
}
}
ctx, cancel := context.WithCancel(context.Background())
h := &HybridBackend{
primary: config.Primary,
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
}
// Start async write worker
h.wg.Add(1)
go h.asyncWriteWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
return h, nil
}
// Set stores a value in both L1 and L2 caches
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Always write to L1 first (synchronous)
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L1 cache: %v", err)
// Continue to try L2 even if L1 fails
} else {
h.l1Writes.Add(1)
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
return nil // Don't fail the operation if L2 is down
}
// Determine if this should be a sync or async write based on cache type
cacheType := h.extractCacheType(key)
requiresSync := h.syncWriteCacheTypes[cacheType]
if requiresSync {
// Synchronous write for critical cache types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
h.recordL2Error()
// Don't fail the operation - L1 write succeeded
return nil
}
h.l2Writes.Add(1)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
} else {
// Asynchronous write for non-critical cache types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
h.logger.Debugf("Queued async write to L2 for key: %s", key)
default:
// Buffer is full, log and continue
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
h.errors.Add(1)
}
}
return nil
}
// Get retrieves a value from cache, checking L1 first, then L2
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
// Try L1 first
value, ttl, exists, err := h.primary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L1 get error for key %s: %v", key, err)
}
if exists {
h.l1Hits.Add(1)
return value, ttl, true, nil
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.misses.Add(1)
return nil, 0, false, nil
}
// Try L2
value, ttl, exists, err = h.secondary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L2 get error for key %s: %v", key, err)
h.recordL2Error()
h.misses.Add(1)
return nil, 0, false, nil // Don't propagate L2 errors
}
if !exists {
h.misses.Add(1)
return nil, 0, false, nil
}
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
return value, ttl, true, nil
}
// Delete removes a key from both L1 and L2 caches
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
var deleted bool
// Delete from L1
if d, err := h.primary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
} else if d {
deleted = true
}
// Delete from L2 if not in fallback mode
if !h.fallbackMode.Load() {
if d, err := h.secondary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
h.recordL2Error()
} else if d {
deleted = true
}
}
return deleted, nil
}
// Exists checks if a key exists in either cache
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
// Check L1 first
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
return true, nil
}
// Check L2 if not in fallback mode
if !h.fallbackMode.Load() {
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
return true, nil
}
}
return false, nil
}
// Clear removes all keys from both caches
func (h *HybridBackend) Clear(ctx context.Context) error {
var lastErr error
// Clear L1
if err := h.primary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L1 cache: %v", err)
lastErr = err
}
// Clear L2 if not in fallback mode
if !h.fallbackMode.Load() {
if err := h.secondary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L2 cache: %v", err)
h.recordL2Error()
lastErr = err
}
}
return lastErr
}
// GetStats returns statistics for the hybrid cache
func (h *HybridBackend) GetStats() map[string]interface{} {
l1Hits := h.l1Hits.Load()
l2Hits := h.l2Hits.Load()
misses := h.misses.Load()
total := l1Hits + l2Hits + misses
stats := map[string]interface{}{
"type": TypeHybrid,
"l1_hits": l1Hits,
"l2_hits": l2Hits,
"misses": misses,
"total": total,
"l1_writes": h.l1Writes.Load(),
"l2_writes": h.l2Writes.Load(),
"errors": h.errors.Load(),
"fallback_mode": h.fallbackMode.Load(),
}
if total > 0 {
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
}
// Add sub-backend stats
stats["l1_stats"] = h.primary.GetStats()
stats["l2_stats"] = h.secondary.GetStats()
// Add last L2 error time if available
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok {
stats["last_l2_error"] = t.Format(time.RFC3339)
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
}
}
return stats
}
// Ping checks if both backends are healthy
func (h *HybridBackend) Ping(ctx context.Context) error {
// Check L1
if err := h.primary.Ping(ctx); err != nil {
return fmt.Errorf("L1 ping failed: %w", err)
}
// Check L2 (but don't fail if it's down)
if err := h.secondary.Ping(ctx); err != nil {
h.logger.Warnf("L2 ping failed: %v", err)
h.recordL2Error()
// Don't return error - we can operate with L1 only
} else {
// L2 is healthy, clear fallback mode if it was set
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend recovered, exiting fallback mode")
}
}
return nil
}
// Close shuts down the hybrid backend
func (h *HybridBackend) Close() error {
// Cancel context to stop workers
h.cancel()
// Close async write channel
close(h.asyncWriteBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Workers finished
case <-time.After(5 * time.Second):
h.logger.Warnf("Timeout waiting for workers to finish")
}
var lastErr error
// Close backends
if err := h.primary.Close(); err != nil {
h.logger.Errorf("Failed to close L1 backend: %v", err)
lastErr = err
}
if err := h.secondary.Close(); err != nil {
h.logger.Errorf("Failed to close L2 backend: %v", err)
lastErr = err
}
h.logger.Infof("HybridBackend closed")
return lastErr
}
// GetMany retrieves multiple values efficiently
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
results := make(map[string][]byte, len(keys))
missingKeys := make([]string, 0)
// Try L1 first for all keys
for _, key := range keys {
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
results[key] = value
h.l1Hits.Add(1)
} else {
missingKeys = append(missingKeys, key)
}
}
// If all found in L1 or in fallback mode, return
if len(missingKeys) == 0 || h.fallbackMode.Load() {
return results, nil
}
// Try L2 for missing keys using batch operation if available
if batcher, ok := h.secondary.(interface {
GetMany(context.Context, []string) (map[string][]byte, error)
}); ok {
l2Results, err := batcher.GetMany(ctx, missingKeys)
if err != nil {
h.logger.Debugf("L2 batch get error: %v", err)
h.recordL2Error()
} else {
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
}
}
} else {
// Fallback to individual gets
for _, key := range missingKeys {
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
}
}
}
// Count misses for keys not found anywhere
for _, key := range keys {
if _, found := results[key]; !found {
h.misses.Add(1)
}
}
return results, nil
}
// SetMany stores multiple key-value pairs efficiently
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if len(items) == 0 {
return nil
}
// Write to L1 first
for key, value := range items {
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
} else {
h.l1Writes.Add(1)
}
}
// Skip L2 if in fallback mode
if h.fallbackMode.Load() {
return nil
}
// Check if L2 supports batch operations
if batcher, ok := h.secondary.(interface {
SetMany(context.Context, map[string][]byte, time.Duration) error
}); ok {
if err := batcher.SetMany(ctx, items, ttl); err != nil {
h.logger.Warnf("Failed to batch write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(int64(len(items)))
}
} else {
// Fallback to individual sets
for key, value := range items {
cacheType := h.extractCacheType(key)
if h.syncWriteCacheTypes[cacheType] {
// Sync write for critical types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
}
} else {
// Async write for non-critical types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
// Queued
default:
h.logger.Warnf("Async buffer full for batch write")
}
}
}
}
return nil
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items with best effort
for len(h.asyncWriteBuffer) > 0 {
select {
case item := <-h.asyncWriteBuffer:
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.asyncWriteBuffer:
if !ok {
return
}
// Skip if in fallback mode
if h.fallbackMode.Load() {
continue
}
// Perform the write with a timeout
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.errors.Add(1)
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
}
cancel()
}
}
}
// healthMonitor periodically checks L2 health and manages fallback mode
func (h *HybridBackend) healthMonitor() {
defer h.wg.Done()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := h.secondary.Ping(ctx); err != nil {
if !h.fallbackMode.Load() {
h.fallbackMode.Store(true)
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
}
} else {
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend healthy, exiting fallback mode")
}
}
cancel()
}
}
}
// recordL2Error records the timestamp of an L2 error
func (h *HybridBackend) recordL2Error() {
h.lastL2Error.Store(time.Now())
// Check if we should enter fallback mode based on recent errors
if !h.fallbackMode.Load() {
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
h.fallbackMode.Store(true)
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
}
}
}
}
// extractCacheType attempts to determine the cache type from the key
func (h *HybridBackend) extractCacheType(key string) string {
// Simple heuristic based on key prefixes
// This should match the actual cache type strategy in the main application
if len(key) > 10 {
prefix := key[:10]
switch {
case contains(prefix, "blacklist"):
return "blacklist"
case contains(prefix, "token"):
return "token"
case contains(prefix, "metadata"):
return "metadata"
case contains(prefix, "jwk"):
return "jwk"
case contains(prefix, "session"):
return "session"
case contains(prefix, "introspect"):
return "introspection"
}
}
return "general"
}
// contains checks if a string contains a substring (case-insensitive)
func contains(s, substr string) bool {
if len(substr) > len(s) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if toLower(s[i+j]) != toLower(substr[j]) {
match = false
break
}
}
if match {
return true
}
}
return false
}
// toLower converts a byte to lowercase
func toLower(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + 32
}
return b
}
File diff suppressed because it is too large Load Diff
+133
View File
@@ -0,0 +1,133 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"time"
)
// CacheBackend defines the interface for all cache backend implementations
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
type CacheBackend interface {
// Set stores a value in the cache with the specified TTL
// Returns an error if the operation fails
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Get retrieves a value from the cache
// Returns: value, remaining TTL, exists flag, and error
// If the key doesn't exist, exists will be false
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
// Delete removes a key from the cache
// Returns true if the key was deleted, false if it didn't exist
Delete(ctx context.Context, key string) (bool, error)
// Exists checks if a key exists in the cache
Exists(ctx context.Context, key string) (bool, error)
// Clear removes all keys from the cache
Clear(ctx context.Context) error
// GetStats returns cache statistics
// Stats include: hits, misses, size, memory usage, etc.
GetStats() map[string]interface{}
// Close shuts down the cache backend and releases resources
Close() error
// Ping checks if the backend is healthy and responsive
Ping(ctx context.Context) error
}
// BackendStats represents statistics for a cache backend
type BackendStats struct {
// Type is the backend type
Type BackendType
// Hits is the number of cache hits
Hits int64
// Misses is the number of cache misses
Misses int64
// Sets is the number of set operations
Sets int64
// Deletes is the number of delete operations
Deletes int64
// Errors is the number of errors
Errors int64
// Evictions is the number of evicted items
Evictions int64
// CurrentSize is the current number of items in cache
CurrentSize int64
// MaxSize is the maximum number of items (0 means unlimited)
MaxSize int64
// MemoryUsage is the approximate memory usage in bytes
MemoryUsage int64
// AverageGetLatency is the average latency for get operations
AverageGetLatency time.Duration
// AverageSetLatency is the average latency for set operations
AverageSetLatency time.Duration
// LastError is the last error encountered
LastError string
// LastErrorTime is when the last error occurred
LastErrorTime time.Time
// Uptime is how long the backend has been running
Uptime time.Duration
// StartTime is when the backend was started
StartTime time.Time
}
// BackendCapabilities describes the capabilities of a cache backend
type BackendCapabilities struct {
// Distributed indicates if the backend is distributed across multiple instances
Distributed bool
// Persistent indicates if the backend persists data across restarts
Persistent bool
// Eviction indicates if the backend supports automatic eviction
Eviction bool
// TTL indicates if the backend supports TTL (time-to-live)
TTL bool
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
MaxKeySize int64
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
MaxValueSize int64
// MaxKeys is the maximum number of keys (0 = unlimited)
MaxKeys int64
// SupportsExpire indicates if the backend supports expiration
SupportsExpire bool
// SupportsMultiGet indicates if the backend supports batch get operations
SupportsMultiGet bool
// SupportsTransaction indicates if the backend supports transactions
SupportsTransaction bool
// SupportsCompression indicates if the backend supports compression
SupportsCompression bool
// RequiresSerialize indicates if values must be serialized
RequiresSerialize bool
// AtomicOperations indicates if the backend supports atomic operations
AtomicOperations bool
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
func TestCacheBackendContract(t *testing.T) {
// Test suite will be run against each backend type
t.Run("MemoryBackend", func(t *testing.T) {
backend := setupMemoryBackend(t)
runContractTests(t, backend)
})
t.Run("RedisBackend", func(t *testing.T) {
backend := setupRedisBackend(t)
runContractTests(t, backend)
})
t.Run("HybridBackend", func(t *testing.T) {
backend := setupHybridBackend(t)
runContractTests(t, backend)
})
}
// runContractTests executes all contract tests against a backend
func runContractTests(t *testing.T, backend CacheBackend) {
t.Helper()
ctx := context.Background()
t.Run("BasicSetGet", func(t *testing.T) {
testBasicSetGet(t, ctx, backend)
})
t.Run("GetNonExistent", func(t *testing.T) {
testGetNonExistent(t, ctx, backend)
})
t.Run("UpdateExisting", func(t *testing.T) {
testUpdateExisting(t, ctx, backend)
})
t.Run("Delete", func(t *testing.T) {
testDelete(t, ctx, backend)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
testDeleteNonExistent(t, ctx, backend)
})
t.Run("Exists", func(t *testing.T) {
testExists(t, ctx, backend)
})
t.Run("TTLExpiration", func(t *testing.T) {
testTTLExpiration(t, ctx, backend)
})
t.Run("Clear", func(t *testing.T) {
testClear(t, ctx, backend)
})
t.Run("Ping", func(t *testing.T) {
testPing(t, ctx, backend)
})
t.Run("Stats", func(t *testing.T) {
testStats(t, ctx, backend)
})
t.Run("ConcurrentAccess", func(t *testing.T) {
testConcurrentAccess(t, ctx, backend)
})
t.Run("LargeValues", func(t *testing.T) {
testLargeValues(t, ctx, backend)
})
t.Run("EmptyValues", func(t *testing.T) {
testEmptyValues(t, ctx, backend)
})
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
testSpecialCharactersInKeys(t, ctx, backend)
})
}
// testBasicSetGet verifies basic set and get operations
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "test-key-1"
value := []byte("test-value-1")
ttl := 1 * time.Minute
// Set value
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err, "Set should not return error")
// Get value
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error")
assert.True(t, exists, "Key should exist")
assert.Equal(t, value, retrieved, "Retrieved value should match")
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
}
// testGetNonExistent verifies behavior when getting non-existent keys
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-key"
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error for non-existent key")
assert.False(t, exists, "Key should not exist")
assert.Nil(t, retrieved, "Value should be nil")
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
}
// testUpdateExisting verifies updating an existing key
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
ttl := 1 * time.Minute
// Set initial value
err := backend.Set(ctx, key, value1, ttl)
require.NoError(t, err)
// Update value
err = backend.Set(ctx, key, value2, ttl)
require.NoError(t, err)
// Verify updated value
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved, "Value should be updated")
}
// testDelete verifies delete operation
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "delete-key"
value := []byte("delete-value")
// Set value
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Verify exists
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Delete
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted, "Delete should return true for existing key")
// Verify deleted
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after delete")
}
// testDeleteNonExistent verifies deleting non-existent keys
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-delete-key"
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.False(t, deleted, "Delete should return false for non-existent key")
}
// testExists verifies the Exists operation
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "exists-key"
value := []byte("exists-value")
// Check non-existent key
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist initially")
// Set value
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check existing key
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist after Set")
}
// testTTLExpiration verifies TTL expiration behavior
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
// Set with short TTL
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist immediately after Set")
// Wait for expiration
time.Sleep(200 * time.Millisecond)
// Verify expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after TTL expiration")
}
// testClear verifies Clear operation
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
// Set multiple keys
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Give async writes time to complete before clearing
// This prevents race conditions with async write workers
time.Sleep(50 * time.Millisecond)
// Clear all
err := backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after Clear")
}
}
// testPing verifies Ping operation
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
err := backend.Ping(ctx)
assert.NoError(t, err, "Ping should succeed on healthy backend")
}
// testStats verifies GetStats operation
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
stats := backend.GetStats()
assert.NotNil(t, stats, "Stats should not be nil")
// Stats should contain basic metrics
_, hasHits := stats["hits"]
_, hasMisses := stats["misses"]
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
}
// testConcurrentAccess verifies thread safety
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
var wg sync.WaitGroup
goroutines := 10
iterations := 20
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
}
}(i)
}
wg.Wait()
}
// testLargeValues verifies handling of large values
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "large-value-key"
value := GenerateLargeValue(1024 * 1024) // 1MB
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle large values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
}
// testEmptyValues verifies handling of empty values
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "empty-value-key"
value := []byte{}
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle empty values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Empty value should exist")
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
}
// testSpecialCharactersInKeys verifies handling of special characters in keys
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
specialKeys := []string{
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
"key|with|pipes",
}
for _, key := range specialKeys {
value := []byte(fmt.Sprintf("value-for-%s", key))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle special character in key: %s", key)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key with special characters should exist: %s", key)
assert.Equal(t, value, retrieved)
}
}
// Helper functions to setup different backend types
// These will be implemented in respective test files
func setupMemoryBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in memory_test.go
// For now, return nil to allow compilation
t.Skip("MemoryBackend implementation pending")
return nil
}
func setupRedisBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in redis_test.go
// For now, return nil to allow compilation
t.Skip("RedisBackend implementation pending")
return nil
}
func setupHybridBackend(t *testing.T) CacheBackend {
t.Helper()
primary := newMockBackend()
secondary := newMockBackend()
config := &HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 100,
Logger: NewTestLogger(t),
}
hybrid, err := NewHybridBackend(config)
require.NoError(t, err)
t.Cleanup(func() {
hybrid.Close()
})
return hybrid
}
+516
View File
@@ -0,0 +1,516 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
key string
value interface{}
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
accessCount int64
size int64
element *list.Element // for LRU tracking
}
// isExpired checks if the item is expired
func (item *memoryCacheItem) isExpired() bool {
if item.expiresAt.IsZero() {
return false
}
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
type MemoryCacheBackend struct {
mu sync.RWMutex
items map[string]*memoryCacheItem
lruList *list.List
maxSize int64
maxMemory int64
currentSize int64
currentMemory int64
// Statistics
hits atomic.Int64
misses atomic.Int64
sets atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
// Latency tracking
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
// Status
startTime time.Time
lastError string
lastErrorTime time.Time
cleanupTicker *time.Ticker
cleanupDone chan bool
closed atomic.Bool
// Configuration
cleanupInterval time.Duration
evictionPolicy string // "lru", "lfu", "fifo"
}
// NewMemoryCacheBackend creates a new memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = 10000 // Default to 10k items
}
if maxMemory <= 0 {
maxMemory = 100 * 1024 * 1024 // Default to 100MB
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
}
m := &MemoryCacheBackend{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
evictionPolicy: "lru",
cleanupDone: make(chan bool),
}
// Start cleanup goroutine
m.cleanupTicker = time.NewTicker(cleanupInterval)
go m.cleanupLoop()
return m
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
select {
case <-m.cleanupTicker.C:
m.cleanupExpired()
case <-m.cleanupDone:
return
}
}
}
// cleanupExpired removes all expired items from the cache
func (m *MemoryCacheBackend) cleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
var keysToDelete []string
for key, item := range m.items {
if item.isExpired() {
keysToDelete = append(keysToDelete, key)
}
}
for _, key := range keysToDelete {
m.deleteItemLocked(key)
}
}
// Get retrieves a value from the cache
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalGetTime.Add(duration)
m.getCount.Add(1)
}()
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
if item.isExpired() {
m.mu.Lock()
m.deleteItemLocked(key)
m.mu.Unlock()
m.misses.Add(1)
return nil, ErrCacheMiss
}
// Update access time and count
m.mu.Lock()
item.accessedAt = time.Now()
item.accessCount++
// Move to front of LRU list
if m.evictionPolicy == "lru" && item.element != nil {
m.lruList.MoveToFront(item.element)
}
m.mu.Unlock()
m.hits.Add(1)
return item.value, nil
}
// Set stores a value in the cache with optional TTL
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalSetTime.Add(duration)
m.setCount.Add(1)
}()
// Calculate item size (simplified estimation)
itemSize := int64(len(key)) + estimateValueSize(value)
m.mu.Lock()
defer m.mu.Unlock()
// Check if we need to evict items
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
m.evictLocked()
}
// Check if key exists
if oldItem, exists := m.items[key]; exists {
m.currentMemory -= oldItem.size
if oldItem.element != nil {
m.lruList.Remove(oldItem.element)
}
} else {
m.currentSize++
}
now := time.Now()
var expiresAt time.Time
if ttl > 0 {
expiresAt = now.Add(ttl)
}
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: itemSize,
}
// Add to LRU list
if m.evictionPolicy == "lru" {
item.element = m.lruList.PushFront(item)
}
m.items[key] = item
m.currentMemory += itemSize
m.sets.Add(1)
return nil
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.items[key]; !exists {
return nil
}
m.deleteItemLocked(key)
m.deletes.Add(1)
return nil
}
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
if item, exists := m.items[key]; exists {
m.currentMemory -= item.size
m.currentSize--
if item.element != nil {
m.lruList.Remove(item.element)
}
delete(m.items, key)
}
}
// evictLocked evicts items based on the eviction policy (must be called with lock held)
func (m *MemoryCacheBackend) evictLocked() {
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
// Evict least recently used item
element := m.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
m.deleteItemLocked(item.key)
m.evictions.Add(1)
}
}
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
return false, nil
}
return !item.isExpired(), nil
}
// Clear removes all items from the cache
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryCacheItem)
m.lruList = list.New()
m.currentSize = 0
m.currentMemory = 0
return nil
}
// Keys returns all keys matching the pattern (use "*" for all keys)
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var keys []string
for key, item := range m.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys, nil
}
// Size returns the number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
return m.currentSize, nil
}
// TTL returns the remaining time-to-live for a key
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists || item.isExpired() {
return 0, ErrCacheMiss
}
if item.expiresAt.IsZero() {
return 0, nil // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, nil
}
return remaining, nil
}
// Expire updates the TTL for an existing key
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists || item.isExpired() {
return ErrCacheMiss
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return nil
}
// GetStats returns statistics about the cache backend
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
m.mu.RUnlock()
avgGetLatency := time.Duration(0)
if getCount := m.getCount.Load(); getCount > 0 {
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
}
avgSetLatency := time.Duration(0)
if setCount := m.setCount.Load(); setCount > 0 {
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
}
return &BackendStats{
Type: TypeMemory,
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Sets: m.sets.Load(),
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: m.currentSize,
MaxSize: m.maxSize,
MemoryUsage: m.currentMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
LastErrorTime: lastErrorTime,
Uptime: time.Since(m.startTime),
StartTime: m.startTime,
}, nil
}
// Ping checks if the backend is healthy
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
return nil
}
// Close closes the backend and releases resources
func (m *MemoryCacheBackend) Close() error {
if m.closed.Swap(true) {
return nil // Already closed
}
m.cleanupTicker.Stop()
close(m.cleanupDone)
m.mu.Lock()
m.items = nil
m.lruList = nil
m.mu.Unlock()
return nil
}
// IsHealthy returns true if the backend is healthy
func (m *MemoryCacheBackend) IsHealthy() bool {
return !m.closed.Load()
}
// Type returns the backend type
func (m *MemoryCacheBackend) Type() BackendType {
return TypeMemory
}
// Capabilities returns the backend capabilities
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
return &BackendCapabilities{
Distributed: false,
Persistent: false,
Eviction: true,
TTL: true,
MaxKeySize: 1024, // 1KB
MaxValueSize: 10485760, // 10MB
MaxKeys: m.maxSize,
SupportsExpire: true,
SupportsMultiGet: true,
SupportsTransaction: false,
SupportsCompression: false,
RequiresSerialize: false,
}
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
// This is a simplified estimation
// In production, you might want to use a more accurate method
switch v := value.(type) {
case string:
return int64(len(v))
case []byte:
return int64(len(v))
case int, int32, int64, uint, uint32, uint64:
return 8
case float32, float64:
return 8
case bool:
return 1
default:
// For complex types, use a default estimate
return 256
}
}
// matchPattern checks if a key matches a pattern (simplified glob matching)
func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching - in production, use a proper glob library
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
}
+182
View File
@@ -0,0 +1,182 @@
package backends
import (
"context"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
)
// setupBenchmarkRedis creates a miniredis instance for benchmarking
func setupBenchmarkRedis(b *testing.B) string {
b.Helper()
mr, err := miniredis.Run()
if err != nil {
b.Fatal(err)
}
b.Cleanup(func() {
mr.Close()
})
return mr.Addr()
}
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
// Perform various operations
_, _ = conn.Do("SET", "bench-key", "bench-value")
_, _ = conn.Do("GET", "bench-key")
_, _ = conn.Do("EXISTS", "bench-key")
_, _ = conn.Do("DEL", "bench-key")
pool.Put(conn)
}
}
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
func BenchmarkRedisBackend_SetGet(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("benchmark test data with some content")
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Set operation
err := backend.Set(ctx, "bench-key", testData, 0)
if err != nil {
b.Fatal(err)
}
// Get operation
_, _, _, err = backend.Get(ctx, "bench-key")
if err != nil {
b.Fatal(err)
}
}
}
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
addr := setupBenchmarkRedis(b)
backend, err := NewRedisBackend(&Config{
RedisAddr: addr,
PoolSize: 10,
})
if err != nil {
b.Fatal(err)
}
defer backend.Close()
ctx := context.Background()
testData := []byte("concurrent benchmark data")
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = backend.Set(ctx, "concurrent-key", testData, 0)
_, _, _, _ = backend.Get(ctx, "concurrent-key")
}
})
}
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
defer pool.Put(conn)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// This tests the pooling of RESPReader/RESPWriter
_, _ = conn.Do("PING")
}
}
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
func BenchmarkConnectionPool_GetPut(b *testing.B) {
addr := setupBenchmarkRedis(b)
config := &PoolConfig{
Address: addr,
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
if err != nil {
b.Fatal(err)
}
defer pool.Close()
ctx := context.Background()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn, err := pool.Get(ctx)
if err != nil {
b.Fatal(err)
}
pool.Put(conn)
}
}
+783
View File
@@ -0,0 +1,783 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMemoryBackend_BasicOperations tests basic CRUD operations
func TestMemoryBackend_BasicOperations(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "test-key"
value := []byte("test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
assert.LessOrEqual(t, remainingTTL, ttl)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "delete-key"
value := []byte("delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
deleted, err := backend.Delete(ctx, "non-existent-delete")
require.NoError(t, err)
assert.False(t, deleted)
})
t.Run("Exists", func(t *testing.T) {
key := "exists-key"
value := []byte("exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
// Add multiple items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
err := backend.Clear(ctx)
require.NoError(t, err)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(0), size)
})
}
// TestMemoryBackend_TTLExpiration tests TTL and expiration
func TestMemoryBackend_TTLExpiration(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.CleanupInterval = 50 * time.Millisecond
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "short-ttl-key"
value := []byte("short-ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait for expiration
time.Sleep(150 * time.Millisecond)
// Should be expired
_, _, exists, err = backend.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLDecrement", func(t *testing.T) {
key := "ttl-decrement-key"
value := []byte("ttl-decrement-value")
ttl := 2 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Check TTL immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait a bit
time.Sleep(500 * time.Millisecond)
// Check TTL again - should be less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
})
t.Run("CleanupExpiredItems", func(t *testing.T) {
// Set multiple items with short TTL
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
err := backend.Set(ctx, key, value, 50*time.Millisecond)
require.NoError(t, err)
}
// Wait for cleanup to run
time.Sleep(200 * time.Millisecond)
// All items should be cleaned up
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Expired items should be cleaned up")
}
})
}
// TestMemoryBackend_LRUEviction tests LRU eviction
func TestMemoryBackend_LRUEviction(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 5
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Fill cache to max size
for i := 0; i < 5; i++ {
key := fmt.Sprintf("lru-key-%d", i)
value := []byte(fmt.Sprintf("lru-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Access first item to make it most recently used
_, _, exists, err := backend.Get(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists)
// Add a new item - should evict lru-key-1 (least recently used)
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
require.NoError(t, err)
// lru-key-0 should still exist (was accessed recently)
exists, err = backend.Exists(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item should not be evicted")
// lru-key-1 should be evicted
exists, err = backend.Exists(ctx, "lru-key-1")
require.NoError(t, err)
assert.False(t, exists, "Least recently used item should be evicted")
// Check eviction count
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestMemoryBackend_MemoryLimit tests memory-based eviction
func TestMemoryBackend_MemoryLimit(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 100
config.MaxMemoryBytes = 1024 // 1KB limit
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items until memory limit is reached
largeValue := make([]byte, 512) // 512 bytes each
for i := 0; i < 5; i++ {
key := fmt.Sprintf("mem-key-%d", i)
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
}
stats := backend.GetStats()
memory := stats["memory"].(int64)
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
}
// TestMemoryBackend_ConcurrentAccess tests thread safety
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
// Random deletes
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
// Verify stats are consistent
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
}
// TestMemoryBackend_UpdateExisting tests updating existing keys
func TestMemoryBackend_UpdateExisting(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
// Size should not increase (same key)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
}
// TestMemoryBackend_Stats tests statistics tracking
func TestMemoryBackend_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add items and track hits/misses
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
// Hit
backend.Get(ctx, "key1")
// Miss
backend.Get(ctx, "non-existent")
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestMemoryBackend_EmptyValues tests handling of empty values
func TestMemoryBackend_EmptyValues(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestMemoryBackend_LargeValues tests handling of large values
func TestMemoryBackend_LargeValues(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestMemoryBackend_Close tests proper cleanup on close
func TestMemoryBackend_Close(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
ctx := context.Background()
// Add some items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations after close should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
_, _, _, err = backend.Get(ctx, "close-key-0")
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Closing again should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestMemoryBackend_Ping tests ping operation
func TestMemoryBackend_Ping(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
func TestMemoryBackend_ValueIsolation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "isolation-key"
originalValue := []byte("original-value")
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
require.NoError(t, err)
// Get value and modify it
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Modify retrieved value
if len(retrieved) > 0 {
retrieved[0] = 'X'
}
// Get again - should be unchanged
retrieved2, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
}
// TestMemoryBackend_Keys tests the Keys method with pattern matching
func TestMemoryBackend_Keys(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add test data
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
for _, key := range testKeys {
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
t.Run("AllKeys", func(t *testing.T) {
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.Len(t, keys, 5)
})
t.Run("SpecificPattern", func(t *testing.T) {
// Simple exact match
keys, err := backend.Keys(ctx, "user:1")
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Contains(t, keys, "user:1")
})
t.Run("ExcludesExpired", func(t *testing.T) {
// Add an expired key
expiredKey := "expired:key"
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
keys, err := backend.Keys(ctx, "*")
require.NoError(t, err)
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.Keys(ctx, "*")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Size tests the Size method
func TestMemoryBackend_Size(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initially empty
size, err := backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), size)
// Add items
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key-%d", i)
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
}
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(5), size)
// Delete one
backend.Delete(ctx, "key-0")
size, err = backend.Size(ctx)
require.NoError(t, err)
assert.Equal(t, int64(4), size)
// After close
backend.Close()
_, err = backend.Size(ctx)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
}
// TestMemoryBackend_TTL tests the TTL method
func TestMemoryBackend_TTL(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ExistingKey", func(t *testing.T) {
key := "ttl-key"
ttl := 1 * time.Minute
err := backend.Set(ctx, key, []byte("value"), ttl)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Greater(t, remaining, 50*time.Second)
assert.LessOrEqual(t, remaining, ttl)
})
t.Run("NonExistentKey", func(t *testing.T) {
_, err := backend.TTL(ctx, "non-existent")
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("NoExpiration", func(t *testing.T) {
key := "no-expiry"
// TTL of 0 typically means no expiration
err := backend.Set(ctx, key, []byte("value"), 0)
require.NoError(t, err)
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
// No expiration returns 0
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
_, err := closedBackend.TTL(ctx, "key")
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_Expire tests the Expire method
func TestMemoryBackend_Expire(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("UpdateTTL", func(t *testing.T) {
key := "expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Update to shorter TTL
err = backend.Expire(ctx, key, 5*time.Second)
require.NoError(t, err)
// Check new TTL
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.LessOrEqual(t, remaining, 5*time.Second)
})
t.Run("NonExistentKey", func(t *testing.T) {
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrCacheMiss, err)
})
t.Run("RemoveExpiration", func(t *testing.T) {
key := "no-expire-key"
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
require.NoError(t, err)
// Set TTL to 0 to remove expiration
err = backend.Expire(ctx, key, 0)
require.NoError(t, err)
// TTL should now be 0
remaining, err := backend.TTL(ctx, key)
require.NoError(t, err)
assert.Equal(t, time.Duration(0), remaining)
})
t.Run("AfterClose", func(t *testing.T) {
closedBackend, _ := NewMemoryBackend(DefaultConfig())
closedBackend.Close()
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendUnavailable, err)
})
}
// TestMemoryBackend_IsHealthy tests the IsHealthy method
func TestMemoryBackend_IsHealthy(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
// Should be healthy when open
assert.True(t, backend.IsHealthy())
// Should be unhealthy after close
backend.Close()
assert.False(t, backend.IsHealthy())
}
// TestMemoryBackend_Type tests the Type method
func TestMemoryBackend_Type(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
backendType := backend.Type()
assert.Equal(t, TypeMemory, backendType)
}
// TestMemoryBackend_Capabilities tests the Capabilities method
func TestMemoryBackend_Capabilities(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
caps := backend.Capabilities()
require.NotNil(t, caps)
// Memory backend should not be distributed or persistent
assert.False(t, caps.Distributed)
assert.False(t, caps.Persistent)
// Should support eviction and TTL
assert.True(t, caps.Eviction)
assert.True(t, caps.TTL)
assert.True(t, caps.SupportsExpire)
assert.True(t, caps.SupportsMultiGet)
// Check limits
assert.Greater(t, caps.MaxKeySize, int64(0))
assert.Greater(t, caps.MaxValueSize, int64(0))
}
// TestMatchPattern tests the matchPattern helper function
func TestMatchPattern(t *testing.T) {
t.Parallel()
tests := []struct {
pattern string
key string
matches bool
}{
{"*", "any-key", true},
{"*", "another", true},
{"user:1", "user:1", true},
{"user:1", "user:2", false},
{"*:suffix", "prefix:suffix", true},
{"*suffix", "prefix-suffix", true},
{"*abc", "xyzabc", true},
{"*abc", "xyz", false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
result := matchPattern(tt.pattern, tt.key)
assert.Equal(t, tt.matches, result)
})
}
}
+153
View File
@@ -0,0 +1,153 @@
package backends
import (
"context"
"time"
)
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
type MemoryBackend struct {
*MemoryCacheBackend
}
// NewMemoryBackend creates a new memory backend from a config
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
maxSize := int64(config.MaxSize)
if maxSize <= 0 {
maxSize = 1000
}
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
return &MemoryBackend{
MemoryCacheBackend: cacheBackend,
}, nil
}
// Set stores a value in the cache with the specified TTL
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
if err == ErrBackendUnavailable {
return ErrBackendClosed
}
return err
}
// Get retrieves a value from the cache
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
val, err := m.MemoryCacheBackend.Get(ctx, key)
if err != nil {
if err == ErrCacheMiss {
return nil, 0, false, nil
}
if err == ErrBackendUnavailable {
return nil, 0, false, ErrBackendClosed
}
return nil, 0, false, err
}
// Get the item directly to check TTL
m.MemoryCacheBackend.mu.RLock()
item, exists := m.MemoryCacheBackend.items[key]
m.MemoryCacheBackend.mu.RUnlock()
if !exists {
return nil, 0, false, nil
}
var ttl time.Duration
if !item.expiresAt.IsZero() {
ttl = time.Until(item.expiresAt)
if ttl < 0 {
ttl = 0
}
}
// Convert interface{} to []byte
var valueBytes []byte
if val != nil {
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, we might need to handle other types
// For now, we'll just return an error
return nil, 0, false, ErrInvalidValue
}
}
return valueBytes, ttl, true, nil
}
// Delete removes a key from the cache
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
// Check if key exists first
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
if err != nil {
return false, err
}
if !exists {
return false, nil
}
err = m.MemoryCacheBackend.Delete(ctx, key)
if err != nil {
return false, err
}
return true, nil
}
// Exists checks if a key exists in the cache
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
return m.MemoryCacheBackend.Exists(ctx, key)
}
// Clear removes all keys from the cache
func (m *MemoryBackend) Clear(ctx context.Context) error {
return m.MemoryCacheBackend.Clear(ctx)
}
// GetStats returns cache statistics
func (m *MemoryBackend) GetStats() map[string]interface{} {
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
// Convert BackendStats to map
hitRate := float64(0)
total := stats.Hits + stats.Misses
if total > 0 {
hitRate = float64(stats.Hits) / float64(total)
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
}
}
// Close shuts down the cache backend and releases resources
func (m *MemoryBackend) Close() error {
return m.MemoryCacheBackend.Close()
}
// Ping checks if the backend is healthy and responsive
func (m *MemoryBackend) Ping(ctx context.Context) error {
return m.MemoryCacheBackend.Ping(ctx)
}
// Ensure MemoryBackend implements CacheBackend
var _ CacheBackend = (*MemoryBackend)(nil)
+455
View File
@@ -0,0 +1,455 @@
package backends
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
)
// Pure-Go Redis client implementation
// Compatible with Yaegi interpreter (no unsafe package)
// Implements RESP protocol for basic Redis operations
var (
ErrPoolExhausted = errors.New("connection pool exhausted")
)
// RedisBackend implements a Redis-based cache backend using pure Go
type RedisBackend struct {
config *Config
pool *ConnectionPool
healthMonitor *HealthMonitor
// Metrics
hits atomic.Int64
misses atomic.Int64
// Lifecycle
closed atomic.Bool
mu sync.Mutex
}
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
func NewRedisBackend(config *Config) (*RedisBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.RedisAddr == "" {
return nil, fmt.Errorf("redis address is required")
}
// Create connection pool with health checks enabled
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(poolConfig)
if err != nil {
return nil, fmt.Errorf("failed to create connection pool: %w", err)
}
// Create health monitor
healthConfig := DefaultHealthMonitorConfig()
healthMonitor := NewHealthMonitor(pool, healthConfig)
backend := &RedisBackend{
config: config,
pool: pool,
healthMonitor: healthMonitor,
}
// Test connectivity
if err := backend.Ping(context.Background()); err != nil {
pool.Close()
return nil, fmt.Errorf("failed to ping Redis: %w", err)
}
// Start health monitoring
healthMonitor.Start()
return backend, nil
}
// Set stores a value in Redis with TTL
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
// Execute with retry logic
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
var err error
// Use PSETEX for millisecond precision, SETEX for second precision
if ttl > 0 {
ttlMillis := ttl.Milliseconds()
if ttlMillis < 1000 {
// Use PSETEX for sub-second TTLs (millisecond precision)
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
} else {
// Use SETEX for larger TTLs (second precision)
ttlSeconds := int(ttl.Seconds())
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
}
} else {
_, err = conn.Do("SET", prefixedKey, string(value))
}
return err
})
}
// Get retrieves a value from Redis
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if r.closed.Load() {
return nil, 0, false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
var resultValue []byte
var resultTTL time.Duration
var resultExists bool
// Execute with retry logic
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
// Get value
resp, err := conn.Do("GET", prefixedKey)
if err != nil {
if errors.Is(err, ErrNilResponse) {
r.misses.Add(1)
resultExists = false
return nil // Not an error, key just doesn't exist
}
return err
}
value, err := RESPString(resp)
if err != nil {
return err
}
// Get TTL
ttlResp, err := conn.Do("TTL", prefixedKey)
if err != nil {
// If TTL fails, still return the value
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = 0
resultExists = true
return nil
}
ttlSeconds, _ := RESPInt(ttlResp)
var ttl time.Duration
if ttlSeconds > 0 {
ttl = time.Duration(ttlSeconds) * time.Second
}
r.hits.Add(1)
resultValue = []byte(value)
resultTTL = ttl
resultExists = true
return nil
})
return resultValue, resultTTL, resultExists, err
}
// Delete removes a key from Redis
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("DEL", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Exists checks if a key exists in Redis
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return false, err
}
defer r.pool.Put(conn)
prefixedKey := r.prefixKey(key)
resp, err := conn.Do("EXISTS", prefixedKey)
if err != nil {
return false, err
}
count, err := RESPInt(resp)
if err != nil {
return false, err
}
return count > 0, nil
}
// Clear removes all keys with the configured prefix
func (r *RedisBackend) Clear(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
// Use FLUSHDB if no prefix (clear entire DB)
if r.config.RedisPrefix == "" {
_, err := conn.Do("FLUSHDB")
return err
}
// With prefix, we need to scan and delete keys
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
pattern := r.config.RedisPrefix + "*"
resp, err := conn.Do("KEYS", pattern)
if err != nil {
return err
}
// Extract keys from array response
keys, ok := resp.([]interface{})
if !ok || len(keys) == 0 {
return nil
}
// Delete each key
for _, keyInterface := range keys {
key, err := RESPString(keyInterface)
if err != nil {
continue
}
conn.Do("DEL", key) // Best effort, ignore errors
}
return nil
}
// GetStats returns backend statistics
func (r *RedisBackend) GetStats() map[string]interface{} {
hits := r.hits.Load()
misses := r.misses.Load()
total := hits + misses
hitRate := float64(0)
if total > 0 {
hitRate = float64(hits) / float64(total)
}
stats := map[string]interface{}{
"backend": "redis-pure-go",
"address": r.config.RedisAddr,
"hits": hits,
"misses": misses,
"hit_rate": hitRate,
"pool": r.pool.Stats(),
}
// Add health monitor stats if available
if r.healthMonitor != nil {
stats["health"] = r.healthMonitor.GetStats()
}
return stats
}
// Ping checks Redis connectivity
func (r *RedisBackend) Ping(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
conn, err := r.pool.Get(ctx)
if err != nil {
return err
}
defer r.pool.Put(conn)
_, err = conn.Do("PING")
return err
}
// Close closes the Redis backend and all connections
func (r *RedisBackend) Close() error {
if r.closed.Swap(true) {
return nil // Already closed
}
r.mu.Lock()
defer r.mu.Unlock()
// Stop health monitor
if r.healthMonitor != nil {
r.healthMonitor.Stop()
}
// Close connection pool
if r.pool != nil {
return r.pool.Close()
}
return nil
}
// prefixKey adds the configured prefix to a key
func (r *RedisBackend) prefixKey(key string) string {
if r.config.RedisPrefix == "" {
return key
}
return r.config.RedisPrefix + key
}
// executeWithRetry executes a Redis operation with exponential backoff retry logic
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 100 * time.Millisecond
for attempt := 0; attempt < maxRetries; attempt++ {
conn, err := r.pool.Get(ctx)
if err != nil {
// If we can't get a connection and this is the last attempt, fail
if attempt == maxRetries-1 {
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
// Execute the operation
err = operation(conn)
r.pool.Put(conn)
// If successful, return
if err == nil {
return nil
}
// If error is not retryable or last attempt, fail
if attempt == maxRetries-1 || !isRetryableError(err) {
return err
}
// Wait with exponential backoff before retrying
delay := baseDelay * time.Duration(1<<uint(attempt))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
return fmt.Errorf("operation failed after %d attempts", maxRetries)
}
// isRetryableError determines if an error is worth retrying
func isRetryableError(err error) bool {
if err == nil {
return false
}
// Retry on connection errors, timeouts, etc.
// Don't retry on application-level errors like wrong type
errMsg := err.Error()
retryablePatterns := []string{
"connection",
"timeout",
"EOF",
"broken pipe",
"reset by peer",
}
for _, pattern := range retryablePatterns {
if contains(errMsg, pattern) {
return true
}
}
return false
}
// SetMany stores multiple values in Redis (batch operation)
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
// For simplicity, execute sequentially (can be optimized with pipelining later)
for key, value := range items {
if err := r.Set(ctx, key, value, ttl); err != nil {
return err
}
}
return nil
}
// GetMany retrieves multiple values from Redis
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
result := make(map[string][]byte)
// For simplicity, execute sequentially
for _, key := range keys {
value, _, exists, err := r.Get(ctx, key)
if err != nil {
return nil, err
}
if exists {
result[key] = value
}
}
return result, nil
}
+176
View File
@@ -0,0 +1,176 @@
package backends
import (
"context"
"sync"
"sync/atomic"
"time"
)
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
type HealthMonitor struct {
pool *ConnectionPool
config *HealthMonitorConfig
// State
healthy atomic.Bool
running atomic.Bool
lastCheckTime atomic.Int64 // Unix timestamp
// Metrics
consecutiveFailures atomic.Int64
totalChecks atomic.Int64
totalFailures atomic.Int64
// Lifecycle
stopChan chan struct{}
wg sync.WaitGroup
}
// HealthMonitorConfig configures the health monitor
type HealthMonitorConfig struct {
CheckInterval time.Duration // How often to check health
Timeout time.Duration // Timeout for health check
UnhealthyThreshold int // Consecutive failures before marking unhealthy
OnHealthChange func(healthy bool)
}
// DefaultHealthMonitorConfig returns default health monitor configuration
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
return &HealthMonitorConfig{
CheckInterval: 5 * time.Second,
Timeout: 3 * time.Second,
UnhealthyThreshold: 3,
}
}
// NewHealthMonitor creates a new health monitor
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
if config == nil {
config = DefaultHealthMonitorConfig()
}
hm := &HealthMonitor{
pool: pool,
config: config,
stopChan: make(chan struct{}),
}
hm.healthy.Store(true) // Assume healthy initially
return hm
}
// Start begins health monitoring
func (hm *HealthMonitor) Start() {
if hm.running.Swap(true) {
return // Already running
}
hm.wg.Add(1)
go hm.monitorLoop()
}
// Stop stops health monitoring
func (hm *HealthMonitor) Stop() {
if !hm.running.Swap(false) {
return // Not running
}
close(hm.stopChan)
hm.wg.Wait()
}
// IsHealthy returns the current health status
func (hm *HealthMonitor) IsHealthy() bool {
return hm.healthy.Load()
}
// GetStats returns health monitor statistics
func (hm *HealthMonitor) GetStats() map[string]interface{} {
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
return map[string]interface{}{
"healthy": hm.healthy.Load(),
"consecutive_failures": hm.consecutiveFailures.Load(),
"total_checks": hm.totalChecks.Load(),
"total_failures": hm.totalFailures.Load(),
"last_check": lastCheck,
}
}
// monitorLoop runs the health check loop
func (hm *HealthMonitor) monitorLoop() {
defer hm.wg.Done()
ticker := time.NewTicker(hm.config.CheckInterval)
defer ticker.Stop()
// Perform initial check immediately
hm.performHealthCheck()
for {
select {
case <-hm.stopChan:
return
case <-ticker.C:
hm.performHealthCheck()
}
}
}
// performHealthCheck executes a health check
func (hm *HealthMonitor) performHealthCheck() {
hm.totalChecks.Add(1)
hm.lastCheckTime.Store(time.Now().Unix())
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
defer cancel()
// Try to get a connection and ping Redis
conn, err := hm.pool.Get(ctx)
if err != nil {
hm.recordFailure()
return
}
defer hm.pool.Put(conn)
// Ping Redis
_, err = conn.Do("PING")
if err != nil {
hm.recordFailure()
return
}
// Success!
hm.recordSuccess()
}
// recordSuccess records a successful health check
func (hm *HealthMonitor) recordSuccess() {
wasHealthy := hm.healthy.Load()
hm.consecutiveFailures.Store(0)
hm.healthy.Store(true)
// Trigger callback if health changed
if !wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(true)
}
}
// recordFailure records a failed health check
func (hm *HealthMonitor) recordFailure() {
hm.totalFailures.Add(1)
failures := hm.consecutiveFailures.Add(1)
wasHealthy := hm.healthy.Load()
// Mark unhealthy if threshold exceeded
if failures >= int64(hm.config.UnhealthyThreshold) {
hm.healthy.Store(false)
// Trigger callback if health changed
if wasHealthy && hm.config.OnHealthChange != nil {
hm.config.OnHealthChange(false)
}
}
}
+421
View File
@@ -0,0 +1,421 @@
package backends
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHealthMonitor_BasicOperation tests basic health monitoring
func TestHealthMonitor_BasicOperation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create health monitor with fast check interval for testing
hmConfig := &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
}
hm := NewHealthMonitor(pool, hmConfig)
require.NotNil(t, hm)
// Initially should be healthy
assert.True(t, hm.IsHealthy())
// Start monitoring
hm.Start()
defer hm.Stop()
// Wait for a few checks
time.Sleep(500 * time.Millisecond)
// Should still be healthy
assert.True(t, hm.IsHealthy())
// Check stats
stats := hm.GetStats()
require.NotNil(t, stats)
assert.True(t, stats["healthy"].(bool))
assert.Greater(t, stats["total_checks"].(int64), int64(0))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var healthChangedCalled atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if !healthy {
healthChangedCalled.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
// Check stats
stats := hm.GetStats()
assert.False(t, stats["healthy"].(bool))
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
assert.Greater(t, stats["total_failures"].(int64), int64(0))
}
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
var recoveryDetected atomic.Bool
hmConfig := &HealthMonitorConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 100 * time.Millisecond,
UnhealthyThreshold: 2,
OnHealthChange: func(healthy bool) {
if healthy {
recoveryDetected.Store(true)
}
},
}
hm := NewHealthMonitor(pool, hmConfig)
hm.Start()
defer hm.Stop()
// Initially healthy
assert.True(t, hm.IsHealthy())
// Simulate Redis errors
mr.SetError("ERR server is down")
// Wait for health checks to detect failure
time.Sleep(350 * time.Millisecond)
// Should now be unhealthy
assert.False(t, hm.IsHealthy(), "Should detect server failure")
// Clear error to simulate recovery
mr.ClearError()
// Wait for recovery
time.Sleep(350 * time.Millisecond)
// Should be healthy again
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
// Consecutive failures should be reset
stats := hm.GetStats()
assert.True(t, stats["healthy"].(bool))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
}
// TestHealthMonitor_StartStop tests start/stop behavior
func TestHealthMonitor_StartStop(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
// Start monitoring
hm.Start()
assert.True(t, hm.running.Load())
// Starting again should be no-op
hm.Start()
assert.True(t, hm.running.Load())
// Stop monitoring
hm.Stop()
assert.False(t, hm.running.Load())
// Stopping again should be no-op
hm.Stop()
assert.False(t, hm.running.Load())
}
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Create multiple monitors
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 150 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
})
// Start both
hm1.Start()
hm2.Start()
// Both should be healthy
time.Sleep(200 * time.Millisecond)
assert.True(t, hm1.IsHealthy())
assert.True(t, hm2.IsHealthy())
// Stop both
hm1.Stop()
hm2.Stop()
// Verify they stopped
assert.False(t, hm1.running.Load())
assert.False(t, hm2.running.Load())
}
// TestHealthMonitor_StatsAccuracy tests stats tracking
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 1 * time.Second,
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Wait for some checks
time.Sleep(550 * time.Millisecond)
stats := hm.GetStats()
// Should have performed multiple checks
totalChecks := stats["total_checks"].(int64)
assert.GreaterOrEqual(t, totalChecks, int64(4))
// All checks should succeed
assert.Equal(t, int64(0), stats["total_failures"].(int64))
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
// Last check time should be recent (within check interval + buffer)
// Use 2s tolerance to account for CI runner load and timing variance
lastCheck := stats["last_check"].(time.Time)
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
}
// TestHealthMonitor_DefaultConfig tests default configuration
func TestHealthMonitor_DefaultConfig(t *testing.T) {
config := DefaultHealthMonitorConfig()
assert.Equal(t, 5*time.Second, config.CheckInterval)
assert.Equal(t, 3*time.Second, config.Timeout)
assert.Equal(t, 3, config.UnhealthyThreshold)
assert.Nil(t, config.OnHealthChange)
}
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1, // Very small pool
ConnectTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
CheckInterval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond, // Short timeout
UnhealthyThreshold: 2,
})
hm.Start()
defer hm.Stop()
// Get the only connection, blocking health checks
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
// Wait for health check attempts
time.Sleep(350 * time.Millisecond)
// Health monitor might mark as unhealthy due to timeouts
stats := hm.GetStats()
t.Logf("Stats with blocked pool: %+v", stats)
// Return connection
pool.Put(conn)
// Wait for recovery
time.Sleep(300 * time.Millisecond)
// Should recover
assert.True(t, hm.IsHealthy())
}
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
func TestConnectionPool_WithHealthChecks(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn))
// Use connection
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse and validate
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
}
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 3,
ConnectTimeout: 5 * time.Second,
EnableHealthCheck: true,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get and return a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
initialTotal := pool.totalConns.Load()
// Close the connection manually to make it stale
conn.Close()
// Get another connection - should detect stale and create new
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
// Connection should be healthy
assert.True(t, pool.isConnectionHealthy(conn2))
pool.Put(conn2)
// Total connections might be same or less (stale removed)
finalTotal := pool.totalConns.Load()
assert.LessOrEqual(t, finalTotal, initialTotal+1)
}
+337
View File
@@ -0,0 +1,337 @@
package backends
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// ConnectionPool manages a pool of Redis connections
// Pure-Go implementation compatible with Yaegi
type ConnectionPool struct {
config *PoolConfig
connections chan *RedisConn
mu sync.Mutex
closed atomic.Bool
// Metrics
activeConns atomic.Int32
totalConns atomic.Int32
gets atomic.Int64
puts atomic.Int64
timeouts atomic.Int64
}
// PoolConfig holds connection pool configuration
type PoolConfig struct {
Address string
Password string
DB int
MaxConnections int
ConnectTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
}
// NewConnectionPool creates a new connection pool
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
if config == nil {
return nil, errors.New("config is required")
}
if config.MaxConnections <= 0 {
config.MaxConnections = 10
}
if config.ConnectTimeout == 0 {
config.ConnectTimeout = 5 * time.Second
}
pool := &ConnectionPool{
config: config,
connections: make(chan *RedisConn, config.MaxConnections),
}
return pool, nil
}
// Get retrieves a connection from the pool or creates a new one
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
if p.closed.Load() {
return nil, ErrBackendClosed
}
p.gets.Add(1)
// Try to get a connection with validation
maxAttempts := 3
for attempt := 0; attempt < maxAttempts; attempt++ {
var conn *RedisConn
var err error
select {
case conn = <-p.connections:
// Reuse existing connection - validate if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
// Connection is stale, close it and try again
conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
default:
// No available connection, create new one if under limit
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
return nil, err
}
// Wait before retry with exponential backoff
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
continue
}
p.activeConns.Add(1)
p.totalConns.Add(1)
return conn, nil
}
// Pool exhausted, wait for a connection with timeout
select {
case conn = <-p.connections:
// Validate connection if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
conn.Close()
p.totalConns.Add(-1)
continue
}
p.activeConns.Add(1)
return conn, nil
case <-ctx.Done():
p.timeouts.Add(1)
return nil, ctx.Err()
case <-time.After(p.config.ConnectTimeout):
p.timeouts.Add(1)
return nil, ErrPoolExhausted
}
}
}
return nil, errors.New("failed to get healthy connection after retries")
}
// Put returns a connection to the pool
func (p *ConnectionPool) Put(conn *RedisConn) {
if conn == nil {
return
}
p.puts.Add(1)
p.activeConns.Add(-1)
if p.closed.Load() || conn.closed.Load() {
conn.Close()
p.totalConns.Add(-1)
return
}
// Return to pool (non-blocking)
select {
case p.connections <- conn:
// Successfully returned to pool
default:
// Pool full, close connection
conn.Close()
p.totalConns.Add(-1)
}
}
// Close closes all connections in the pool
func (p *ConnectionPool) Close() error {
if p.closed.Swap(true) {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
close(p.connections)
// Close all pooled connections
for conn := range p.connections {
conn.Close()
}
return nil
}
// Stats returns pool statistics
func (p *ConnectionPool) Stats() map[string]interface{} {
return map[string]interface{}{
"active_connections": p.activeConns.Load(),
"total_connections": p.totalConns.Load(),
"max_connections": p.config.MaxConnections,
"gets": p.gets.Load(),
"puts": p.puts.Load(),
"timeouts": p.timeouts.Load(),
}
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
redisConn := &RedisConn{
conn: conn,
readTimeout: p.config.ReadTimeout,
writeTimeout: p.config.WriteTimeout,
}
// Authenticate if password is provided
if p.config.Password != "" {
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
redisConn.Close()
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
// Select database
if p.config.DB != 0 {
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
redisConn.Close()
return nil, fmt.Errorf("failed to select database: %w", err)
}
}
return redisConn, nil
}
// RedisConn represents a single Redis connection
type RedisConn struct {
conn net.Conn
readTimeout time.Duration
writeTimeout time.Duration
closed atomic.Bool
mu sync.Mutex
}
// Do executes a Redis command and returns the response
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
if c.closed.Load() {
return nil, ErrBackendClosed
}
c.mu.Lock()
defer c.mu.Unlock()
// Build command arguments
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
// Limit to a safe value that prevents integer overflow in allocation size calculation
// (capacity * sizeof(string) must fit in int/size_t)
argsLen := len(args)
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
if argsLen < 0 || argsLen > maxSafeArgs {
return nil, errors.New("too many arguments")
}
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
totalBytes := len(command)
for _, s := range args {
// Protect against possible overflow
if len(s) > maxTotalArgBytes-totalBytes {
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
}
totalBytes += len(s)
if totalBytes > maxTotalArgBytes {
return nil, errors.New("total argument size exceeds maximum allowed")
}
}
cmdArgs := make([]string, 0, argsLen+1)
cmdArgs = append(cmdArgs, command)
cmdArgs = append(cmdArgs, args...)
// Set write timeout
if c.writeTimeout > 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// Write command (using pooled writer for memory efficiency)
writer := NewRESPWriter(c.conn)
err := writer.WriteCommand(cmdArgs...)
writer.Release() // Return to pool immediately after use
if err != nil {
c.closed.Store(true)
return nil, err
}
// Set read timeout
if c.readTimeout > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
// Read response (using pooled reader for memory efficiency)
reader := NewRESPReader(c.conn)
resp, err := reader.ReadResponse()
reader.Release() // Return to pool immediately after use
if err != nil {
if !errors.Is(err, ErrNilResponse) {
c.closed.Store(true)
}
return nil, err
}
return resp, nil
}
// Close closes the connection
func (c *RedisConn) Close() error {
if c.closed.Swap(true) {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
return c.conn.Close()
}
return nil
}
// isConnectionHealthy validates a connection is still working
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
if conn == nil || conn.closed.Load() {
return false
}
// Set a read deadline for the ping
if conn.conn != nil {
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
}
_, err := conn.Do("PING")
return err == nil
}
+620
View File
@@ -0,0 +1,620 @@
package backends
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnectionPool_BasicOperations tests basic pool operations
func TestConnectionPool_BasicOperations(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
t.Run("GetAndPutConnection", func(t *testing.T) {
ctx := context.Background()
// Get a connection
conn, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn)
// Verify connection works
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
// Return to pool
pool.Put(conn)
// Get again - should reuse same connection
conn2, err := pool.Get(ctx)
require.NoError(t, err)
require.NotNil(t, conn2)
pool.Put(conn2)
})
t.Run("Stats", func(t *testing.T) {
stats := pool.Stats()
require.NotNil(t, stats)
assert.Contains(t, stats, "active_connections")
assert.Contains(t, stats, "total_connections")
assert.Contains(t, stats, "max_connections")
assert.Equal(t, 5, stats["max_connections"])
})
}
// TestConnectionPool_MaxConnections tests pool size limits
func TestConnectionPool_MaxConnections(t *testing.T) {
mr := NewMiniredisServer(t)
maxConns := 3
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: maxConns,
ConnectTimeout: 1 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Get max connections
conns := make([]*RedisConn, maxConns)
for i := 0; i < maxConns; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
conns[i] = conn
}
// Verify stats
stats := pool.Stats()
assert.Equal(t, int32(maxConns), stats["total_connections"])
assert.Equal(t, int32(maxConns), stats["active_connections"])
// Try to get one more - should block/timeout
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := pool.Get(ctx2)
require.Error(t, err)
require.Nil(t, conn)
// Return one connection
pool.Put(conns[0])
// Now we should be able to get a connection
conn, err = pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
// Cleanup
pool.Put(conn)
for i := 1; i < maxConns; i++ {
pool.Put(conns[i])
}
}
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 10,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
numGoroutines := 50
numOperations := 20
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*numOperations)
// Spawn goroutines
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
conn, err := pool.Get(ctx)
if err != nil {
errors <- err
continue
}
// Do some work
_, err = conn.Do("PING")
if err != nil {
errors <- err
}
// Return to pool
pool.Put(conn)
// Small delay
time.Sleep(time.Millisecond)
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
errorCount := 0
for err := range errors {
t.Logf("Error: %v", err)
errorCount++
}
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
// Verify stats
stats := pool.Stats()
t.Logf("Final stats: %+v", stats)
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
assert.Equal(t, int32(0), stats["active_connections"])
}
// TestConnectionPool_ContextCancellation tests context cancellation
func TestConnectionPool_ContextCancellation(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get the only connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
conn2, err := pool.Get(ctx)
require.Error(t, err)
require.Nil(t, conn2)
assert.Contains(t, err.Error(), "context canceled")
// Cleanup
pool.Put(conn)
}
// TestConnectionPool_Authentication tests auth support
func TestConnectionPool_Authentication(t *testing.T) {
mr := NewMiniredisServer(t)
// Set password on miniredis
mr.server.RequireAuth("secret-password")
t.Run("CorrectPassword", func(t *testing.T) {
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "secret-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
})
t.Run("WrongPassword", func(t *testing.T) {
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
config := &PoolConfig{
Address: mr.GetAddr(),
Password: "wrong-password",
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
_, err := NewConnectionPool(config)
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
})
}
// TestConnectionPool_DatabaseSelection tests DB selection
func TestConnectionPool_DatabaseSelection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
DB: 5,
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Connection should be on DB 5
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestConnectionPool_ClosedConnection tests handling closed connections
func TestConnectionPool_ClosedConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Get connection
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Close it manually
conn.Close()
// Try to use it
_, err = conn.Do("PING")
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Return to pool (should be discarded)
pool.Put(conn)
// Get new connection - should create a new one
conn2, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn2)
resp, err := conn2.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn2)
}
// TestConnectionPool_Close tests pool closure
func TestConnectionPool_Close(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
ConnectTimeout: 5 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
// Get some connections
conns := make([]*RedisConn, 3)
for i := 0; i < 3; i++ {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
conns[i] = conn
}
// Return them
for _, conn := range conns {
pool.Put(conn)
}
// Close pool
err = pool.Close()
require.NoError(t, err)
// Try to get connection from closed pool
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.True(t, errors.Is(err, ErrBackendClosed))
// Close again should be no-op
err = pool.Close()
require.NoError(t, err)
}
// TestConnectionPool_Timeouts tests various timeout scenarios
func TestConnectionPool_Timeouts(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
ConnectTimeout: 100 * time.Millisecond,
ReadTimeout: 100 * time.Millisecond,
WriteTimeout: 100 * time.Millisecond,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Normal operation should work
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
pool.Put(conn)
}
// TestRedisConn_DoCommand tests the Do method
func TestRedisConn_DoCommand(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
t.Run("SET and GET", func(t *testing.T) {
// SET
resp, err := conn.Do("SET", "testkey", "testvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// GET
resp, err = conn.Do("GET", "testkey")
require.NoError(t, err)
assert.Equal(t, "testvalue", resp)
})
t.Run("DEL", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "delkey", "delvalue")
require.NoError(t, err)
// DEL
resp, err := conn.Do("DEL", "delkey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
})
t.Run("EXISTS", func(t *testing.T) {
// SET key first
_, err := conn.Do("SET", "existskey", "value")
require.NoError(t, err)
// EXISTS - key exists
resp, err := conn.Do("EXISTS", "existskey")
require.NoError(t, err)
count, err := RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(1), count)
// EXISTS - key doesn't exist
resp, err = conn.Do("EXISTS", "nonexistent")
require.NoError(t, err)
count, err = RESPInt(resp)
require.NoError(t, err)
assert.Equal(t, int64(0), count)
})
t.Run("TTL commands", func(t *testing.T) {
// SETEX
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
require.NoError(t, err)
assert.Equal(t, "OK", resp)
// TTL
resp, err = conn.Do("TTL", "ttlkey")
require.NoError(t, err)
ttl, err := RESPInt(resp)
require.NoError(t, err)
assert.Greater(t, ttl, int64(0))
assert.LessOrEqual(t, ttl, int64(60))
})
}
// TestPoolConfig_Defaults tests default configuration values
func TestPoolConfig_Defaults(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
// Leave other fields at zero values
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Should use defaults
assert.Equal(t, 10, pool.config.MaxConnections)
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
// Verify it works
conn, err := pool.Get(context.Background())
require.NoError(t, err)
pool.Put(conn)
}
// TestConnectionPool_NilConnection tests handling nil connections
func TestConnectionPool_NilConnection(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 2,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
// Putting nil should be safe
pool.Put(nil)
// Pool should still work
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
pool.Put(conn)
}
// TestConnectionPool_StatsTracking tests metrics tracking
func TestConnectionPool_StatsTracking(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 5,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
// Initial stats
stats := pool.Stats()
initialGets := stats["gets"].(int64)
initialPuts := stats["puts"].(int64)
// Perform operations
numOps := 10
for i := 0; i < numOps; i++ {
conn, err := pool.Get(ctx)
require.NoError(t, err)
pool.Put(conn)
}
// Check updated stats
stats = pool.Stats()
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
assert.Equal(t, int32(0), stats["active_connections"].(int32))
}
// TestRedisConn_TooManyArguments tests protection against allocation overflow
func TestRedisConn_TooManyArguments(t *testing.T) {
mr := NewMiniredisServer(t)
config := &PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
}
pool, err := NewConnectionPool(config)
require.NoError(t, err)
defer pool.Close()
ctx := context.Background()
conn, err := pool.Get(ctx)
require.NoError(t, err)
defer pool.Put(conn)
t.Run("AcceptableArgumentCount", func(t *testing.T) {
// Should work with reasonable number of args
args := make([]string, 100)
for i := range args {
args[i] = "value"
}
_, err := conn.Do("MSET", args...)
// May fail due to Redis constraints, but shouldn't panic or error on overflow
// Just verify it doesn't trigger our overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
t.Run("RejectExcessiveArguments", func(t *testing.T) {
// Create an absurdly large number of arguments that would cause overflow
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
args := make([]string, 1<<20) // 1,048,576 args
for i := range args {
args[i] = "x"
}
_, err := conn.Do("MSET", args...)
require.Error(t, err)
assert.Contains(t, err.Error(), "too many arguments")
})
t.Run("BoundaryCase", func(t *testing.T) {
// Test exactly at the boundary (maxSafeArgs)
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
for i := range args {
args[i] = "x"
}
_, err := conn.Do("ECHO", args...)
// Should not error due to overflow protection
if err != nil {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
}
+545
View File
@@ -0,0 +1,545 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisBackend_BasicOperations tests basic Redis operations
func TestRedisBackend_BasicOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "redis-test-key"
value := []byte("redis-test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "redis-delete-key"
value := []byte("redis-delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
key := "redis-exists-key"
value := []byte("redis-exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
}
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
func TestRedisBackend_KeyPrefixing(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "test:prefix:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "my-key"
value := []byte("my-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check that key is stored with prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, "test:prefix:my-key", keys[0])
// Get should work without prefix
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
}
// TestRedisBackend_TTLExpiration tests TTL handling
func TestRedisBackend_TTLExpiration(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward time in miniredis
mr.FastForward(150 * time.Millisecond)
// Should be expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLRemaining", func(t *testing.T) {
key := "ttl-remaining-key"
value := []byte("ttl-remaining-value")
ttl := 10 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Get immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward 2 seconds
mr.FastForward(2 * time.Second)
// Check TTL is less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1)
})
}
// TestRedisBackend_Clear tests clearing all keys
func TestRedisBackend_Clear(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "clear-test:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add multiple keys
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Verify keys exist
keys := mr.CheckKeys()
assert.Len(t, keys, 10)
// Clear all
err = backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
keys = mr.CheckKeys()
assert.Len(t, keys, 0)
}
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
func TestRedisBackend_ConnectionFailure(t *testing.T) {
t.Parallel()
// Try to connect to non-existent Redis
config := DefaultRedisConfig("localhost:9999")
_, err := NewRedisBackend(config)
assert.Error(t, err, "Should fail to connect to non-existent Redis")
}
// TestRedisBackend_RedisErrors tests handling of Redis errors
func TestRedisBackend_RedisErrors(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Simulate Redis error
mr.SetError("simulated error")
// Operations should fail
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
assert.Error(t, err)
// Clear error
mr.ClearError()
// Operations should work again
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
assert.NoError(t, err)
}
// TestRedisBackend_ConcurrentAccess tests thread safety
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0))
}
// TestRedisBackend_Stats tests statistics tracking
func TestRedisBackend_Stats(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add and access items
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Get(ctx, "key1") // Hit
backend.Get(ctx, "non-existent") // Miss
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestRedisBackend_Ping tests health check
func TestRedisBackend_Ping(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestRedisBackend_Close tests proper cleanup
func TestRedisBackend_Close(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Double close should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestRedisBackend_UpdateExisting tests updating existing keys
func TestRedisBackend_UpdateExisting(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute)
}
// TestRedisBackend_LargeValues tests handling of large values
func TestRedisBackend_LargeValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestRedisBackend_EmptyValues tests handling of empty values
func TestRedisBackend_EmptyValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestRedisBackend_PipelineOperations tests batch operations
func TestRedisBackend_PipelineOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetMany", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
key := fmt.Sprintf("batch-key-%d", i)
value := []byte(fmt.Sprintf("batch-value-%d", i))
items[key] = value
}
err := backend.SetMany(ctx, items, 1*time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, retrieved)
}
})
t.Run("GetMany", func(t *testing.T) {
// Set test data
testData := GenerateTestData(5)
for key, value := range testData {
backend.Set(ctx, key, value, 1*time.Minute)
}
// Get all keys
keys := make([]string, 0, len(testData))
for key := range testData {
keys = append(keys, key)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, len(testData))
for key, expectedValue := range testData {
retrievedValue, exists := results[key]
assert.True(t, exists)
assert.Equal(t, expectedValue, retrievedValue)
}
})
t.Run("GetManyWithNonExistent", func(t *testing.T) {
keys := []string{"exists-1", "non-existent", "exists-2"}
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-1"), results["exists-1"])
assert.Equal(t, []byte("value-2"), results["exists-2"])
_, exists := results["non-existent"]
assert.False(t, exists)
})
}
// TestRedisBackend_NoPrefix tests operation without prefix
func TestRedisBackend_NoPrefix(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "" // No prefix
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "no-prefix-key"
value := []byte("no-prefix-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check key is stored without prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, key, keys[0])
}
+251
View File
@@ -0,0 +1,251 @@
package backends
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
}
// Release returns the writer to the pool for reuse
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
}
// WriteCommand writes a Redis command in RESP array format
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
func (w *RESPWriter) WriteCommand(args ...string) error {
// Write array header
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
return err
}
// Write each argument as bulk string
for _, arg := range args {
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
return err
}
}
return nil
}
// RESPReader reads RESP protocol messages
type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
}
// Release returns the reader to the pool for reuse
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
}
// ReadResponse reads a RESP response and returns the parsed value
func (r *RESPReader) ReadResponse() (interface{}, error) {
typeByte, err := r.r.ReadByte()
if err != nil {
return nil, err
}
switch typeByte {
case '+': // Simple string
return r.readSimpleString()
case '-': // Error
return nil, r.readError()
case ':': // Integer
return r.readInteger()
case '$': // Bulk string
return r.readBulkString()
case '*': // Array
return r.readArray()
default:
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
}
}
// readSimpleString reads a simple string (+OK\r\n)
func (r *RESPReader) readSimpleString() (string, error) {
line, err := r.readLine()
if err != nil {
return "", err
}
return line, nil
}
// readError reads an error message (-Error message\r\n)
func (r *RESPReader) readError() error {
line, err := r.readLine()
if err != nil {
return err
}
return errors.New(line)
}
// readInteger reads an integer (:1000\r\n)
func (r *RESPReader) readInteger() (int64, error) {
line, err := r.readLine()
if err != nil {
return 0, err
}
return strconv.ParseInt(line, 10, 64)
}
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
func (r *RESPReader) readBulkString() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
}
// -1 indicates nil bulk string
if length == -1 {
return nil, ErrNilResponse
}
// Read exactly 'length' bytes plus \r\n
buf := make([]byte, length+2)
if _, err := io.ReadFull(r.r, buf); err != nil {
return nil, err
}
// Verify \r\n terminator
if buf[length] != '\r' || buf[length+1] != '\n' {
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
}
return string(buf[:length]), nil
}
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
func (r *RESPReader) readArray() (interface{}, error) {
line, err := r.readLine()
if err != nil {
return nil, err
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
}
// -1 indicates nil array
if length == -1 {
return nil, ErrNilResponse
}
// Read each element
result := make([]interface{}, length)
for i := 0; i < length; i++ {
elem, err := r.ReadResponse()
if err != nil {
return nil, err
}
result[i] = elem
}
return result, nil
}
// readLine reads a line terminated by \r\n
func (r *RESPReader) readLine() (string, error) {
line, err := r.r.ReadString('\n')
if err != nil {
return "", err
}
// Remove \r\n
line = strings.TrimSuffix(line, "\r\n")
if !strings.HasSuffix(line+"\r\n", "\r\n") {
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
}
return line, nil
}
// RESPString extracts a string from RESP response
func RESPString(resp interface{}) (string, error) {
if resp == nil {
return "", ErrNilResponse
}
switch v := resp.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
default:
return "", fmt.Errorf("expected string, got %T", resp)
}
}
// RESPInt extracts an integer from RESP response
func RESPInt(resp interface{}) (int64, error) {
if resp == nil {
return 0, ErrNilResponse
}
switch v := resp.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
default:
return 0, fmt.Errorf("expected integer, got %T", resp)
}
}
+495
View File
@@ -0,0 +1,495 @@
package backends
import (
"bytes"
"errors"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRESPWriter_WriteCommand tests RESP command writing
func TestRESPWriter_WriteCommand(t *testing.T) {
tests := []struct {
name string
args []string
expected string
}{
{
name: "Simple command",
args: []string{"PING"},
expected: "*1\r\n$4\r\nPING\r\n",
},
{
name: "SET command",
args: []string{"SET", "key", "value"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
},
{
name: "SETEX command",
args: []string{"SETEX", "mykey", "60", "myvalue"},
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
},
{
name: "DEL with multiple keys",
args: []string{"DEL", "key1", "key2", "key3"},
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
},
{
name: "Command with empty string",
args: []string{"SET", "key", ""},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
},
{
name: "Command with special characters",
args: []string{"SET", "key", "val\r\nue"},
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
writer := NewRESPWriter(buf)
err := writer.WriteCommand(tt.args...)
require.NoError(t, err)
assert.Equal(t, tt.expected, buf.String())
})
}
}
// TestRESPReader_ReadSimpleString tests reading simple strings
func TestRESPReader_ReadSimpleString(t *testing.T) {
tests := []struct {
name string
input string
expected string
wantErr bool
}{
{
name: "OK response",
input: "+OK\r\n",
expected: "OK",
wantErr: false,
},
{
name: "PONG response",
input: "+PONG\r\n",
expected: "PONG",
wantErr: false,
},
{
name: "Empty string",
input: "+\r\n",
expected: "",
wantErr: false,
},
{
name: "String with spaces",
input: "+Hello World\r\n",
expected: "Hello World",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadError tests reading error messages
func TestRESPReader_ReadError(t *testing.T) {
tests := []struct {
name string
input string
expectedError string
}{
{
name: "ERR error",
input: "-ERR unknown command\r\n",
expectedError: "ERR unknown command",
},
{
name: "WRONGTYPE error",
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
},
{
name: "Simple error",
input: "-Error\r\n",
expectedError: "Error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.Equal(t, tt.expectedError, err.Error())
})
}
}
// TestRESPReader_ReadInteger tests reading integers
func TestRESPReader_ReadInteger(t *testing.T) {
tests := []struct {
name string
input string
expected int64
wantErr bool
}{
{
name: "Zero",
input: ":0\r\n",
expected: 0,
wantErr: false,
},
{
name: "Positive integer",
input: ":1000\r\n",
expected: 1000,
wantErr: false,
},
{
name: "Negative integer",
input: ":-1\r\n",
expected: -1,
wantErr: false,
},
{
name: "Large integer",
input: ":9223372036854775807\r\n",
expected: 9223372036854775807,
wantErr: false,
},
{
name: "Invalid integer",
input: ":abc\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadBulkString tests reading bulk strings
func TestRESPReader_ReadBulkString(t *testing.T) {
tests := []struct {
name string
input string
expected interface{}
wantErr bool
isNil bool
}{
{
name: "Simple bulk string",
input: "$6\r\nfoobar\r\n",
expected: "foobar",
wantErr: false,
},
{
name: "Empty bulk string",
input: "$0\r\n\r\n",
expected: "",
wantErr: false,
},
{
name: "Nil bulk string",
input: "$-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Binary safe bulk string",
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
expected: "\x00\x01\x02\x03\x04",
wantErr: false,
},
{
name: "Invalid length",
input: "$abc\r\ntest\r\n",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_ReadArray tests reading arrays
func TestRESPReader_ReadArray(t *testing.T) {
tests := []struct {
name string
input string
expected []interface{}
wantErr bool
isNil bool
}{
{
name: "Empty array",
input: "*0\r\n",
expected: []interface{}{},
wantErr: false,
},
{
name: "Array of bulk strings",
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
expected: []interface{}{
"foo",
"bar",
},
wantErr: false,
},
{
name: "Array of integers",
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
},
wantErr: false,
},
{
name: "Mixed array",
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
expected: []interface{}{
int64(1),
int64(2),
int64(3),
int64(4),
"foobar",
},
wantErr: false,
},
{
name: "Nil array",
input: "*-1\r\n",
expected: nil,
wantErr: true,
isNil: true,
},
{
name: "Nested arrays",
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
expected: []interface{}{
[]interface{}{"foo", "bar"},
[]interface{}{"baz"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
result, err := reader.ReadResponse()
if tt.isNil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
return
}
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRESPReader_InvalidInput tests error handling for invalid input
func TestRESPReader_InvalidInput(t *testing.T) {
tests := []struct {
name string
input string
}{
{
name: "Unknown type byte",
input: "?invalid\r\n",
},
{
name: "Incomplete response",
input: "+OK",
},
{
name: "Missing CRLF in bulk string",
input: "$5\r\nhello",
},
{
name: "Truncated array",
input: "*3\r\n:1\r\n:2\r\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := NewRESPReader(strings.NewReader(tt.input))
_, err := reader.ReadResponse()
require.Error(t, err)
})
}
}
// TestRESPReader_EOF tests handling of EOF
func TestRESPReader_EOF(t *testing.T) {
reader := NewRESPReader(strings.NewReader(""))
_, err := reader.ReadResponse()
require.Error(t, err)
assert.True(t, errors.Is(err, io.EOF))
}
// TestRESPHelpers tests helper functions
func TestRESPHelpers(t *testing.T) {
t.Run("RESPString", func(t *testing.T) {
// Valid string
result, err := RESPString("hello")
require.NoError(t, err)
assert.Equal(t, "hello", result)
// Byte slice
result, err = RESPString([]byte("world"))
require.NoError(t, err)
assert.Equal(t, "world", result)
// Nil
_, err = RESPString(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPString(123)
require.Error(t, err)
})
t.Run("RESPInt", func(t *testing.T) {
// Valid int64
result, err := RESPInt(int64(42))
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Valid int
result, err = RESPInt(42)
require.NoError(t, err)
assert.Equal(t, int64(42), result)
// Nil
_, err = RESPInt(nil)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
// Invalid type
_, err = RESPInt("string")
require.Error(t, err)
})
}
// TestRESPRoundTrip tests full round-trip encoding/decoding
func TestRESPRoundTrip(t *testing.T) {
tests := []struct {
name string
command []string
response string
expected interface{}
}{
{
name: "PING command",
command: []string{"PING"},
response: "+PONG\r\n",
expected: "PONG",
},
{
name: "GET command with result",
command: []string{"GET", "mykey"},
response: "$7\r\nmyvalue\r\n",
expected: "myvalue",
},
{
name: "GET command with nil",
command: []string{"GET", "nonexistent"},
response: "$-1\r\n",
expected: nil,
},
{
name: "DEL command",
command: []string{"DEL", "key1", "key2"},
response: ":2\r\n",
expected: int64(2),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Write command
writeBuf := &bytes.Buffer{}
writer := NewRESPWriter(writeBuf)
err := writer.WriteCommand(tt.command...)
require.NoError(t, err)
// Read response
reader := NewRESPReader(strings.NewReader(tt.response))
result, err := reader.ReadResponse()
if tt.expected == nil {
require.Error(t, err)
assert.True(t, errors.Is(err, ErrNilResponse))
} else {
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
}
}
+198
View File
@@ -0,0 +1,198 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
// TestLogger implements a simple logger for tests
type TestLogger struct {
t *testing.T
}
func NewTestLogger(t *testing.T) *TestLogger {
return &TestLogger{t: t}
}
func (l *TestLogger) Debug(format string, args ...interface{}) {
l.t.Logf("[DEBUG] "+format, args...)
}
func (l *TestLogger) Info(format string, args ...interface{}) {
l.t.Logf("[INFO] "+format, args...)
}
func (l *TestLogger) Error(format string, args ...interface{}) {
l.t.Logf("[ERROR] "+format, args...)
}
func (l *TestLogger) Debugf(format string, args ...interface{}) {
l.Debug(format, args...)
}
func (l *TestLogger) Infof(format string, args ...interface{}) {
l.Info(format, args...)
}
func (l *TestLogger) Errorf(format string, args ...interface{}) {
l.Error(format, args...)
}
func (l *TestLogger) Warnf(format string, args ...interface{}) {
l.t.Logf("[WARN] "+format, args...)
}
// MiniredisServer manages a miniredis instance for testing
type MiniredisServer struct {
server *miniredis.Miniredis
client *redis.Client
}
// NewMiniredisServer creates a new miniredis server for testing
func NewMiniredisServer(t *testing.T) *MiniredisServer {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err, "failed to start miniredis")
client := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
// Verify connection
ctx := context.Background()
err = client.Ping(ctx).Err()
require.NoError(t, err, "failed to ping miniredis")
t.Cleanup(func() {
client.Close()
mr.Close()
})
return &MiniredisServer{
server: mr,
client: client,
}
}
// GetAddr returns the address of the miniredis server
func (m *MiniredisServer) GetAddr() string {
return m.server.Addr()
}
// GetClient returns the Redis client
func (m *MiniredisServer) GetClient() *redis.Client {
return m.client
}
// FastForward advances the miniredis server's time
func (m *MiniredisServer) FastForward(d time.Duration) {
m.server.FastForward(d)
}
// FlushAll removes all keys from the database
func (m *MiniredisServer) FlushAll() {
m.server.FlushAll()
}
// SetError simulates a Redis error
func (m *MiniredisServer) SetError(err string) {
m.server.SetError(err)
}
// ClearError clears any simulated errors
func (m *MiniredisServer) ClearError() {
m.server.SetError("")
}
// CheckKeys verifies that specific keys exist in Redis
func (m *MiniredisServer) CheckKeys() []string {
return m.server.Keys()
}
// Close closes the miniredis server
func (m *MiniredisServer) Close() {
m.server.Close()
}
// Restart restarts the miniredis server
func (m *MiniredisServer) Restart() {
m.server.Restart()
}
// TestConfig provides default test configuration
type TestConfig struct {
MaxSize int
DefaultTTL time.Duration
CleanupInterval time.Duration
EnableMetrics bool
}
// DefaultTestConfig returns a standard test configuration
func DefaultTestConfig() *TestConfig {
return &TestConfig{
MaxSize: 100,
DefaultTTL: 5 * time.Minute,
CleanupInterval: 1 * time.Second,
EnableMetrics: true,
}
}
// GenerateTestData creates test cache data
func GenerateTestData(count int) map[string][]byte {
data := make(map[string][]byte, count)
for i := 0; i < count; i++ {
key := fmt.Sprintf("test-key-%d", i)
value := []byte(fmt.Sprintf("test-value-%d", i))
data[key] = value
}
return data
}
// GenerateLargeValue creates a large test value
func GenerateLargeValue(sizeBytes int) []byte {
return make([]byte, sizeBytes)
}
// AssertCacheStats is a helper to verify cache statistics
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
t.Helper()
hits, ok := stats["hits"].(int64)
require.True(t, ok, "hits should be int64")
require.Equal(t, expectedHits, hits, "unexpected hit count")
misses, ok := stats["misses"].(int64)
require.True(t, ok, "misses should be int64")
require.Equal(t, expectedMisses, misses, "unexpected miss count")
}
// WaitForCondition waits for a condition to be true or times out
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(checkInterval)
}
t.Fatal("timeout waiting for condition")
}
// AssertEventuallyExpires verifies that a key eventually expires
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
t.Helper()
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
_, _, exists, err := backend.Get(ctx, key)
return err == nil && !exists
})
}