mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Add redis support for distributed caching (#83)
* Add redis support for distributed caching * Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * fixup! fixup! fixup! fixup! fixup! Move towards the self-provided Redis connection pool and RESP protocol implementation. Official redis client library won't work with yaegi. * ... and another all nighter. * fixup! ... and another all nighter. * fixup! fixup! ... and another all nighter. * fixup! fixup! fixup! ... and another all nighter. * Resolve issue #85 by adding ability to set custom claims in JWT tokens * Remove redundant validation in auth middleware ( issue #89 ) * Add ability to set cookie prefix for session cookies ( #87 ) * fixup! Add ability to set cookie prefix for session cookies ( #87 ) * Add ability to set cookie max age - issue #91 * Potential fix for code scanning alert no. 10: Size computation for allocation may overflow Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> * fixup! Merge main into 0.8.0-redis: resolve conflicts --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
Vendored
+90
@@ -0,0 +1,90 @@
|
||||
package backends
|
||||
|
||||
import "time"
|
||||
|
||||
// BackendType represents the type of cache backend
|
||||
type BackendType string
|
||||
|
||||
const (
|
||||
BackendTypeMemory BackendType = "memory"
|
||||
BackendTypeRedis BackendType = "redis"
|
||||
BackendTypeHybrid BackendType = "hybrid"
|
||||
|
||||
// Aliases for backward compatibility
|
||||
TypeMemory BackendType = "memory"
|
||||
TypeRedis BackendType = "redis"
|
||||
TypeHybrid BackendType = "hybrid"
|
||||
)
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns a default configuration for Redis caching
|
||||
func DefaultRedisConfig(addr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeRedis,
|
||||
RedisAddr: addr,
|
||||
RedisDB: 0,
|
||||
RedisPrefix: "traefikoidc:",
|
||||
PoolSize: 10,
|
||||
EnableCircuitBreaker: true,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHybridConfig returns a default configuration for hybrid caching
|
||||
func DefaultHybridConfig(redisAddr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeHybrid,
|
||||
L1Config: &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 500,
|
||||
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
},
|
||||
L2Config: DefaultRedisConfig(redisAddr),
|
||||
AsyncWrites: true,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
+59
@@ -0,0 +1,59 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDefaultHybridConfig verifies the default hybrid configuration
|
||||
func TestDefaultHybridConfig(t *testing.T) {
|
||||
redisAddr := "localhost:6379"
|
||||
|
||||
config := DefaultHybridConfig(redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Verify top-level config
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.True(t, config.AsyncWrites)
|
||||
assert.True(t, config.EnableMetrics)
|
||||
|
||||
// Verify L1 (memory) config
|
||||
require.NotNil(t, config.L1Config)
|
||||
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
|
||||
assert.Equal(t, 500, config.L1Config.MaxSize)
|
||||
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
|
||||
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
|
||||
|
||||
// Verify L2 (Redis) config exists
|
||||
require.NotNil(t, config.L2Config)
|
||||
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
|
||||
}
|
||||
|
||||
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
redisAddr string
|
||||
}{
|
||||
{"localhost", "localhost:6379"},
|
||||
{"remote host", "redis.example.com:6379"},
|
||||
{"IP address", "192.168.1.100:6379"},
|
||||
{"custom port", "localhost:6380"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := DefaultHybridConfig(tt.redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.NotNil(t, config.L1Config)
|
||||
assert.NotNil(t, config.L2Config)
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
+38
@@ -0,0 +1,38 @@
|
||||
package backends
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrBackendClosed is returned when operating on a closed backend
|
||||
ErrBackendClosed = errors.New("cache backend is closed")
|
||||
|
||||
// ErrKeyNotFound is returned when a key doesn't exist
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
|
||||
// ErrCacheMiss indicates the requested key was not found in the cache
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// ErrBackendUnavailable indicates the cache backend is not available
|
||||
ErrBackendUnavailable = errors.New("cache backend unavailable")
|
||||
|
||||
// ErrInvalidValue indicates the cached value is invalid or corrupted
|
||||
ErrInvalidValue = errors.New("invalid cached value")
|
||||
|
||||
// ErrInvalidTTL is returned when TTL is invalid
|
||||
ErrInvalidTTL = errors.New("invalid TTL")
|
||||
|
||||
// ErrConnectionFailed is returned when connection fails
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrCircuitOpen is returned when circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTimeout is returned when operation times out
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
|
||||
// ErrSerializationFailed is returned when serialization fails
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
|
||||
// ErrDeserializationFailed is returned when deserialization fails
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
)
|
||||
Vendored
+695
@@ -0,0 +1,695 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// defaultLogger provides a basic logger implementation
|
||||
type defaultLogger struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
l.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
|
||||
l.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// HybridConfig provides configuration for the hybrid backend
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.Primary == nil {
|
||||
return nil, fmt.Errorf("primary (L1) backend is required")
|
||||
}
|
||||
|
||||
if config.Secondary == nil {
|
||||
return nil, fmt.Errorf("secondary (L2) backend is required")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
|
||||
}
|
||||
|
||||
if config.AsyncBufferSize <= 0 {
|
||||
config.AsyncBufferSize = 1000
|
||||
}
|
||||
|
||||
// Default critical cache types that require synchronous writes
|
||||
if config.SyncWriteCacheTypes == nil {
|
||||
config.SyncWriteCacheTypes = map[string]bool{
|
||||
"blacklist": true, // Token blacklist must be immediately consistent
|
||||
"token": true, // Token validation is critical
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &HybridBackend{
|
||||
primary: config.Primary,
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
// Start async write worker
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
|
||||
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
|
||||
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
|
||||
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Set stores a value in both L1 and L2 caches
|
||||
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Always write to L1 first (synchronous)
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L1 cache: %v", err)
|
||||
// Continue to try L2 even if L1 fails
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
// Determine if this should be a sync or async write based on cache type
|
||||
cacheType := h.extractCacheType(key)
|
||||
requiresSync := h.syncWriteCacheTypes[cacheType]
|
||||
|
||||
if requiresSync {
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from cache, checking L1 first, then L2
|
||||
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
// Try L1 first
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
h.l1Hits.Add(1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
// Try L2
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
}
|
||||
|
||||
if !exists {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from both L1 and L2 caches
|
||||
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
var deleted bool
|
||||
|
||||
// Delete from L1
|
||||
if d, err := h.primary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
|
||||
// Delete from L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if d, err := h.secondary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in either cache
|
||||
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// Check L1 first
|
||||
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys from both caches
|
||||
func (h *HybridBackend) Clear(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
// Clear L1
|
||||
if err := h.primary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L1 cache: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
// Clear L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if err := h.secondary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetStats returns statistics for the hybrid cache
|
||||
func (h *HybridBackend) GetStats() map[string]interface{} {
|
||||
l1Hits := h.l1Hits.Load()
|
||||
l2Hits := h.l2Hits.Load()
|
||||
misses := h.misses.Load()
|
||||
total := l1Hits + l2Hits + misses
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": TypeHybrid,
|
||||
"l1_hits": l1Hits,
|
||||
"l2_hits": l2Hits,
|
||||
"misses": misses,
|
||||
"total": total,
|
||||
"l1_writes": h.l1Writes.Load(),
|
||||
"l2_writes": h.l2Writes.Load(),
|
||||
"errors": h.errors.Load(),
|
||||
"fallback_mode": h.fallbackMode.Load(),
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
|
||||
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
|
||||
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
|
||||
}
|
||||
|
||||
// Add sub-backend stats
|
||||
stats["l1_stats"] = h.primary.GetStats()
|
||||
stats["l2_stats"] = h.secondary.GetStats()
|
||||
|
||||
// Add last L2 error time if available
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok {
|
||||
stats["last_l2_error"] = t.Format(time.RFC3339)
|
||||
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks if both backends are healthy
|
||||
func (h *HybridBackend) Ping(ctx context.Context) error {
|
||||
// Check L1
|
||||
if err := h.primary.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("L1 ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check L2 (but don't fail if it's down)
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
h.logger.Warnf("L2 ping failed: %v", err)
|
||||
h.recordL2Error()
|
||||
// Don't return error - we can operate with L1 only
|
||||
} else {
|
||||
// L2 is healthy, clear fallback mode if it was set
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend recovered, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the hybrid backend
|
||||
func (h *HybridBackend) Close() error {
|
||||
// Cancel context to stop workers
|
||||
h.cancel()
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Workers finished
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warnf("Timeout waiting for workers to finish")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Close backends
|
||||
if err := h.primary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L1 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if err := h.secondary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L2 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
h.logger.Infof("HybridBackend closed")
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values efficiently
|
||||
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
results := make(map[string][]byte, len(keys))
|
||||
missingKeys := make([]string, 0)
|
||||
|
||||
// Try L1 first for all keys
|
||||
for _, key := range keys {
|
||||
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
|
||||
results[key] = value
|
||||
h.l1Hits.Add(1)
|
||||
} else {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// If all found in L1 or in fallback mode, return
|
||||
if len(missingKeys) == 0 || h.fallbackMode.Load() {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Try L2 for missing keys using batch operation if available
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
GetMany(context.Context, []string) (map[string][]byte, error)
|
||||
}); ok {
|
||||
l2Results, err := batcher.GetMany(ctx, missingKeys)
|
||||
if err != nil {
|
||||
h.logger.Debugf("L2 batch get error: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual gets
|
||||
for _, key := range missingKeys {
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count misses for keys not found anywhere
|
||||
for _, key := range keys {
|
||||
if _, found := results[key]; !found {
|
||||
h.misses.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetMany stores multiple key-value pairs efficiently
|
||||
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write to L1 first
|
||||
for key, value := range items {
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip L2 if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if L2 supports batch operations
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
SetMany(context.Context, map[string][]byte, time.Duration) error
|
||||
}); ok {
|
||||
if err := batcher.SetMany(ctx, items, ttl); err != nil {
|
||||
h.logger.Warnf("Failed to batch write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(int64(len(items)))
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual sets
|
||||
for key, value := range items {
|
||||
cacheType := h.extractCacheType(key)
|
||||
if h.syncWriteCacheTypes[cacheType] {
|
||||
// Sync write for critical types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
}
|
||||
} else {
|
||||
// Async write for non-critical types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
// Queued
|
||||
default:
|
||||
h.logger.Warnf("Async buffer full for batch write")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items with best effort
|
||||
for len(h.asyncWriteBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.asyncWriteBuffer:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.asyncWriteBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the write with a timeout
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthMonitor periodically checks L2 health and manages fallback mode
|
||||
func (h *HybridBackend) healthMonitor() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
if !h.fallbackMode.Load() {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
|
||||
}
|
||||
} else {
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend healthy, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordL2Error records the timestamp of an L2 error
|
||||
func (h *HybridBackend) recordL2Error() {
|
||||
h.lastL2Error.Store(time.Now())
|
||||
|
||||
// Check if we should enter fallback mode based on recent errors
|
||||
if !h.fallbackMode.Load() {
|
||||
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractCacheType attempts to determine the cache type from the key
|
||||
func (h *HybridBackend) extractCacheType(key string) string {
|
||||
// Simple heuristic based on key prefixes
|
||||
// This should match the actual cache type strategy in the main application
|
||||
|
||||
if len(key) > 10 {
|
||||
prefix := key[:10]
|
||||
switch {
|
||||
case contains(prefix, "blacklist"):
|
||||
return "blacklist"
|
||||
case contains(prefix, "token"):
|
||||
return "token"
|
||||
case contains(prefix, "metadata"):
|
||||
return "metadata"
|
||||
case contains(prefix, "jwk"):
|
||||
return "jwk"
|
||||
case contains(prefix, "session"):
|
||||
return "session"
|
||||
case contains(prefix, "introspect"):
|
||||
return "introspection"
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if toLower(s[i+j]) != toLower(substr[j]) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower converts a byte to lowercase
|
||||
func toLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + 32
|
||||
}
|
||||
return b
|
||||
}
|
||||
+1490
File diff suppressed because it is too large
Load Diff
Vendored
+133
@@ -0,0 +1,133 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheBackend defines the interface for all cache backend implementations
|
||||
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
|
||||
type CacheBackend interface {
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
// Returns an error if the operation fails
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
// Returns: value, remaining TTL, exists flag, and error
|
||||
// If the key doesn't exist, exists will be false
|
||||
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
|
||||
|
||||
// Delete removes a key from the cache
|
||||
// Returns true if the key was deleted, false if it didn't exist
|
||||
Delete(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// GetStats returns cache statistics
|
||||
// Stats include: hits, misses, size, memory usage, etc.
|
||||
GetStats() map[string]interface{}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
Close() error
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
type BackendCapabilities struct {
|
||||
// Distributed indicates if the backend is distributed across multiple instances
|
||||
Distributed bool
|
||||
|
||||
// Persistent indicates if the backend persists data across restarts
|
||||
Persistent bool
|
||||
|
||||
// Eviction indicates if the backend supports automatic eviction
|
||||
Eviction bool
|
||||
|
||||
// TTL indicates if the backend supports TTL (time-to-live)
|
||||
TTL bool
|
||||
|
||||
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
|
||||
MaxKeySize int64
|
||||
|
||||
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
|
||||
MaxValueSize int64
|
||||
|
||||
// MaxKeys is the maximum number of keys (0 = unlimited)
|
||||
MaxKeys int64
|
||||
|
||||
// SupportsExpire indicates if the backend supports expiration
|
||||
SupportsExpire bool
|
||||
|
||||
// SupportsMultiGet indicates if the backend supports batch get operations
|
||||
SupportsMultiGet bool
|
||||
|
||||
// SupportsTransaction indicates if the backend supports transactions
|
||||
SupportsTransaction bool
|
||||
|
||||
// SupportsCompression indicates if the backend supports compression
|
||||
SupportsCompression bool
|
||||
|
||||
// RequiresSerialize indicates if values must be serialized
|
||||
RequiresSerialize bool
|
||||
|
||||
// AtomicOperations indicates if the backend supports atomic operations
|
||||
AtomicOperations bool
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
|
||||
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
|
||||
func TestCacheBackendContract(t *testing.T) {
|
||||
// Test suite will be run against each backend type
|
||||
t.Run("MemoryBackend", func(t *testing.T) {
|
||||
backend := setupMemoryBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("RedisBackend", func(t *testing.T) {
|
||||
backend := setupRedisBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("HybridBackend", func(t *testing.T) {
|
||||
backend := setupHybridBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// runContractTests executes all contract tests against a backend
|
||||
func runContractTests(t *testing.T, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BasicSetGet", func(t *testing.T) {
|
||||
testBasicSetGet(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
testGetNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
testUpdateExisting(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testDelete(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
testDeleteNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
testExists(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("TTLExpiration", func(t *testing.T) {
|
||||
testTTLExpiration(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
testClear(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
testPing(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
testStats(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
testConcurrentAccess(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("LargeValues", func(t *testing.T) {
|
||||
testLargeValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("EmptyValues", func(t *testing.T) {
|
||||
testEmptyValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
|
||||
testSpecialCharactersInKeys(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// testBasicSetGet verifies basic set and get operations
|
||||
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "test-key-1"
|
||||
value := []byte("test-value-1")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err, "Set should not return error")
|
||||
|
||||
// Get value
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error")
|
||||
assert.True(t, exists, "Key should exist")
|
||||
assert.Equal(t, value, retrieved, "Retrieved value should match")
|
||||
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
|
||||
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
|
||||
}
|
||||
|
||||
// testGetNonExistent verifies behavior when getting non-existent keys
|
||||
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-key"
|
||||
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error for non-existent key")
|
||||
assert.False(t, exists, "Key should not exist")
|
||||
assert.Nil(t, retrieved, "Value should be nil")
|
||||
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
|
||||
}
|
||||
|
||||
// testUpdateExisting verifies updating an existing key
|
||||
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set initial value
|
||||
err := backend.Set(ctx, key, value1, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update value
|
||||
err = backend.Set(ctx, key, value2, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated value
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved, "Value should be updated")
|
||||
}
|
||||
|
||||
// testDelete verifies delete operation
|
||||
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted, "Delete should return true for existing key")
|
||||
|
||||
// Verify deleted
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after delete")
|
||||
}
|
||||
|
||||
// testDeleteNonExistent verifies deleting non-existent keys
|
||||
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-delete-key"
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted, "Delete should return false for non-existent key")
|
||||
}
|
||||
|
||||
// testExists verifies the Exists operation
|
||||
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist initially")
|
||||
|
||||
// Set value
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist after Set")
|
||||
}
|
||||
|
||||
// testTTLExpiration verifies TTL expiration behavior
|
||||
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
// Set with short TTL
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist immediately after Set")
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after TTL expiration")
|
||||
}
|
||||
|
||||
// testClear verifies Clear operation
|
||||
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
// Set multiple keys
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Give async writes time to complete before clearing
|
||||
// This prevents race conditions with async write workers
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Clear all
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after Clear")
|
||||
}
|
||||
}
|
||||
|
||||
// testPing verifies Ping operation
|
||||
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
err := backend.Ping(ctx)
|
||||
assert.NoError(t, err, "Ping should succeed on healthy backend")
|
||||
}
|
||||
|
||||
// testStats verifies GetStats operation
|
||||
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
stats := backend.GetStats()
|
||||
assert.NotNil(t, stats, "Stats should not be nil")
|
||||
|
||||
// Stats should contain basic metrics
|
||||
_, hasHits := stats["hits"]
|
||||
_, hasMisses := stats["misses"]
|
||||
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
|
||||
}
|
||||
|
||||
// testConcurrentAccess verifies thread safety
|
||||
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 20
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// testLargeValues verifies handling of large values
|
||||
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "large-value-key"
|
||||
value := GenerateLargeValue(1024 * 1024) // 1MB
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle large values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
|
||||
}
|
||||
|
||||
// testEmptyValues verifies handling of empty values
|
||||
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "empty-value-key"
|
||||
value := []byte{}
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle empty values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Empty value should exist")
|
||||
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
|
||||
}
|
||||
|
||||
// testSpecialCharactersInKeys verifies handling of special characters in keys
|
||||
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
specialKeys := []string{
|
||||
"key:with:colons",
|
||||
"key/with/slashes",
|
||||
"key-with-dashes",
|
||||
"key_with_underscores",
|
||||
"key.with.dots",
|
||||
"key|with|pipes",
|
||||
}
|
||||
|
||||
for _, key := range specialKeys {
|
||||
value := []byte(fmt.Sprintf("value-for-%s", key))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle special character in key: %s", key)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key with special characters should exist: %s", key)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to setup different backend types
|
||||
// These will be implemented in respective test files
|
||||
|
||||
func setupMemoryBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in memory_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("MemoryBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRedisBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in redis_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("RedisBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHybridBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
config := &HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 100,
|
||||
Logger: NewTestLogger(t),
|
||||
}
|
||||
|
||||
hybrid, err := NewHybridBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
hybrid.Close()
|
||||
})
|
||||
|
||||
return hybrid
|
||||
}
|
||||
Vendored
+516
@@ -0,0 +1,516 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
func (item *memoryCacheItem) isExpired() bool {
|
||||
if item.expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Statistics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
m.cleanupTicker = time.NewTicker(cleanupInterval)
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupExpired()
|
||||
case <-m.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalGetTime.Add(duration)
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalSetTime.Add(duration)
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching the pattern (use "*" for all keys)
|
||||
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.currentSize, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the cache backend
|
||||
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
m.mu.RUnlock()
|
||||
|
||||
avgGetLatency := time.Duration(0)
|
||||
if getCount := m.getCount.Load(); getCount > 0 {
|
||||
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
|
||||
}
|
||||
|
||||
avgSetLatency := time.Duration(0)
|
||||
if setCount := m.setCount.Load(); setCount > 0 {
|
||||
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
|
||||
}
|
||||
|
||||
return &BackendStats{
|
||||
Type: TypeMemory,
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Sets: m.sets.Load(),
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
LastErrorTime: lastErrorTime,
|
||||
Uptime: time.Since(m.startTime),
|
||||
StartTime: m.startTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the backend and releases resources
|
||||
func (m *MemoryCacheBackend) Close() error {
|
||||
if m.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (m *MemoryCacheBackend) IsHealthy() bool {
|
||||
return !m.closed.Load()
|
||||
}
|
||||
|
||||
// Type returns the backend type
|
||||
func (m *MemoryCacheBackend) Type() BackendType {
|
||||
return TypeMemory
|
||||
}
|
||||
|
||||
// Capabilities returns the backend capabilities
|
||||
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
return &BackendCapabilities{
|
||||
Distributed: false,
|
||||
Persistent: false,
|
||||
Eviction: true,
|
||||
TTL: true,
|
||||
MaxKeySize: 1024, // 1KB
|
||||
MaxValueSize: 10485760, // 10MB
|
||||
MaxKeys: m.maxSize,
|
||||
SupportsExpire: true,
|
||||
SupportsMultiGet: true,
|
||||
SupportsTransaction: false,
|
||||
SupportsCompression: false,
|
||||
RequiresSerialize: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
case []byte:
|
||||
return int64(len(v))
|
||||
case int, int32, int64, uint, uint32, uint64:
|
||||
return 8
|
||||
case float32, float64:
|
||||
return 8
|
||||
case bool:
|
||||
return 1
|
||||
default:
|
||||
// For complex types, use a default estimate
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
// matchPattern checks if a key matches a pattern (simplified glob matching)
|
||||
func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
)
|
||||
|
||||
// setupBenchmarkRedis creates a miniredis instance for benchmarking
|
||||
func setupBenchmarkRedis(b *testing.B) string {
|
||||
b.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
return mr.Addr()
|
||||
}
|
||||
|
||||
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
|
||||
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform various operations
|
||||
_, _ = conn.Do("SET", "bench-key", "bench-value")
|
||||
_, _ = conn.Do("GET", "bench-key")
|
||||
_, _ = conn.Do("EXISTS", "bench-key")
|
||||
_, _ = conn.Do("DEL", "bench-key")
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
|
||||
func BenchmarkRedisBackend_SetGet(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("benchmark test data with some content")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Set operation
|
||||
err := backend.Set(ctx, "bench-key", testData, 0)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get operation
|
||||
_, _, _, err = backend.Get(ctx, "bench-key")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
|
||||
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("concurrent benchmark data")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = backend.Set(ctx, "concurrent-key", testData, 0)
|
||||
_, _, _, _ = backend.Get(ctx, "concurrent-key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
|
||||
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// This tests the pooling of RESPReader/RESPWriter
|
||||
_, _ = conn.Do("PING")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
|
||||
func BenchmarkConnectionPool_GetPut(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
+783
@@ -0,0 +1,783 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMemoryBackend_BasicOperations tests basic CRUD operations
|
||||
func TestMemoryBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
assert.LessOrEqual(t, remainingTTL, ttl)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
deleted, err := backend.Delete(ctx, "non-existent-delete")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
// Add multiple items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(0), size)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTLExpiration tests TTL and expiration
|
||||
func TestMemoryBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "short-ttl-key"
|
||||
value := []byte("short-ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLDecrement", func(t *testing.T) {
|
||||
key := "ttl-decrement-key"
|
||||
value := []byte("ttl-decrement-value")
|
||||
ttl := 2 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check TTL again - should be less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
|
||||
})
|
||||
|
||||
t.Run("CleanupExpiredItems", func(t *testing.T) {
|
||||
// Set multiple items with short TTL
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 50*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// All items should be cleaned up
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expired items should be cleaned up")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LRUEviction tests LRU eviction
|
||||
func TestMemoryBackend_LRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 5
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill cache to max size
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("lru-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("lru-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Access first item to make it most recently used
|
||||
_, _, exists, err := backend.Get(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Add a new item - should evict lru-key-1 (least recently used)
|
||||
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// lru-key-0 should still exist (was accessed recently)
|
||||
exists, err = backend.Exists(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item should not be evicted")
|
||||
|
||||
// lru-key-1 should be evicted
|
||||
exists, err = backend.Exists(ctx, "lru-key-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Least recently used item should be evicted")
|
||||
|
||||
// Check eviction count
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_MemoryLimit tests memory-based eviction
|
||||
func TestMemoryBackend_MemoryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100
|
||||
config.MaxMemoryBytes = 1024 // 1KB limit
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items until memory limit is reached
|
||||
largeValue := make([]byte, 512) // 512 bytes each
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("mem-key-%d", i)
|
||||
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
memory := stats["memory"].(int64)
|
||||
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
|
||||
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ConcurrentAccess tests thread safety
|
||||
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// Random deletes
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_UpdateExisting tests updating existing keys
|
||||
func TestMemoryBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
|
||||
|
||||
// Size should not increase (same key)
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Stats tests statistics tracking
|
||||
func TestMemoryBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add items and track hits/misses
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Hit
|
||||
backend.Get(ctx, "key1")
|
||||
// Miss
|
||||
backend.Get(ctx, "non-existent")
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_EmptyValues tests handling of empty values
|
||||
func TestMemoryBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LargeValues tests handling of large values
|
||||
func TestMemoryBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Close tests proper cleanup on close
|
||||
func TestMemoryBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
_, _, _, err = backend.Get(ctx, "close-key-0")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Closing again should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Ping tests ping operation
|
||||
func TestMemoryBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
|
||||
func TestMemoryBackend_ValueIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "isolation-key"
|
||||
originalValue := []byte("original-value")
|
||||
|
||||
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value and modify it
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Modify retrieved value
|
||||
if len(retrieved) > 0 {
|
||||
retrieved[0] = 'X'
|
||||
}
|
||||
|
||||
// Get again - should be unchanged
|
||||
retrieved2, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Keys tests the Keys method with pattern matching
|
||||
func TestMemoryBackend_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add test data
|
||||
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
|
||||
for _, key := range testKeys {
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("AllKeys", func(t *testing.T) {
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 5)
|
||||
})
|
||||
|
||||
t.Run("SpecificPattern", func(t *testing.T) {
|
||||
// Simple exact match
|
||||
keys, err := backend.Keys(ctx, "user:1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Contains(t, keys, "user:1")
|
||||
})
|
||||
|
||||
t.Run("ExcludesExpired", func(t *testing.T) {
|
||||
// Add an expired key
|
||||
expiredKey := "expired:key"
|
||||
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.Keys(ctx, "*")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Size tests the Size method
|
||||
func TestMemoryBackend_Size(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially empty
|
||||
size, err := backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), size)
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), size)
|
||||
|
||||
// Delete one
|
||||
backend.Delete(ctx, "key-0")
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), size)
|
||||
|
||||
// After close
|
||||
backend.Close()
|
||||
_, err = backend.Size(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTL tests the TTL method
|
||||
func TestMemoryBackend_TTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, []byte("value"), ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, remaining, 50*time.Second)
|
||||
assert.LessOrEqual(t, remaining, ttl)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
_, err := backend.TTL(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("NoExpiration", func(t *testing.T) {
|
||||
key := "no-expiry"
|
||||
// TTL of 0 typically means no expiration
|
||||
err := backend.Set(ctx, key, []byte("value"), 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
// No expiration returns 0
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.TTL(ctx, "key")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Expire tests the Expire method
|
||||
func TestMemoryBackend_Expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateTTL", func(t *testing.T) {
|
||||
key := "expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update to shorter TTL
|
||||
err = backend.Expire(ctx, key, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check new TTL
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, remaining, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("RemoveExpiration", func(t *testing.T) {
|
||||
key := "no-expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set TTL to 0 to remove expiration
|
||||
err = backend.Expire(ctx, key, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TTL should now be 0
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_IsHealthy tests the IsHealthy method
|
||||
func TestMemoryBackend_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be healthy when open
|
||||
assert.True(t, backend.IsHealthy())
|
||||
|
||||
// Should be unhealthy after close
|
||||
backend.Close()
|
||||
assert.False(t, backend.IsHealthy())
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Type tests the Type method
|
||||
func TestMemoryBackend_Type(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
backendType := backend.Type()
|
||||
assert.Equal(t, TypeMemory, backendType)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Capabilities tests the Capabilities method
|
||||
func TestMemoryBackend_Capabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
caps := backend.Capabilities()
|
||||
require.NotNil(t, caps)
|
||||
|
||||
// Memory backend should not be distributed or persistent
|
||||
assert.False(t, caps.Distributed)
|
||||
assert.False(t, caps.Persistent)
|
||||
|
||||
// Should support eviction and TTL
|
||||
assert.True(t, caps.Eviction)
|
||||
assert.True(t, caps.TTL)
|
||||
assert.True(t, caps.SupportsExpire)
|
||||
assert.True(t, caps.SupportsMultiGet)
|
||||
|
||||
// Check limits
|
||||
assert.Greater(t, caps.MaxKeySize, int64(0))
|
||||
assert.Greater(t, caps.MaxValueSize, int64(0))
|
||||
}
|
||||
|
||||
// TestMatchPattern tests the matchPattern helper function
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
pattern string
|
||||
key string
|
||||
matches bool
|
||||
}{
|
||||
{"*", "any-key", true},
|
||||
{"*", "another", true},
|
||||
{"user:1", "user:1", true},
|
||||
{"user:1", "user:2", false},
|
||||
{"*:suffix", "prefix:suffix", true},
|
||||
{"*suffix", "prefix-suffix", true},
|
||||
{"*abc", "xyzabc", true},
|
||||
{"*abc", "xyz", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.key)
|
||||
assert.Equal(t, tt.matches, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
+153
@@ -0,0 +1,153 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
|
||||
type MemoryBackend struct {
|
||||
*MemoryCacheBackend
|
||||
}
|
||||
|
||||
// NewMemoryBackend creates a new memory backend from a config
|
||||
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
|
||||
maxSize := int64(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1000
|
||||
}
|
||||
|
||||
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
|
||||
return &MemoryBackend{
|
||||
MemoryCacheBackend: cacheBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
|
||||
if err == ErrBackendUnavailable {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
val, err := m.MemoryCacheBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
if err == ErrCacheMiss {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err == ErrBackendUnavailable {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
var valueBytes []byte
|
||||
if val != nil {
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
|
||||
return valueBytes, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
// Check if key exists first
|
||||
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = m.MemoryCacheBackend.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return m.MemoryCacheBackend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
func (m *MemoryBackend) Clear(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BackendStats to map
|
||||
hitRate := float64(0)
|
||||
total := stats.Hits + stats.Misses
|
||||
if total > 0 {
|
||||
hitRate = float64(stats.Hits) / float64(total)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
func (m *MemoryBackend) Close() error {
|
||||
return m.MemoryCacheBackend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
func (m *MemoryBackend) Ping(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Ping(ctx)
|
||||
}
|
||||
|
||||
// Ensure MemoryBackend implements CacheBackend
|
||||
var _ CacheBackend = (*MemoryBackend)(nil)
|
||||
Vendored
+455
@@ -0,0 +1,455 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pure-Go Redis client implementation
|
||||
// Compatible with Yaegi interpreter (no unsafe package)
|
||||
// Implements RESP protocol for basic Redis operations
|
||||
|
||||
var (
|
||||
ErrPoolExhausted = errors.New("connection pool exhausted")
|
||||
)
|
||||
|
||||
// RedisBackend implements a Redis-based cache backend using pure Go
|
||||
type RedisBackend struct {
|
||||
config *Config
|
||||
pool *ConnectionPool
|
||||
healthMonitor *HealthMonitor
|
||||
|
||||
// Metrics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
|
||||
func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.RedisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is required")
|
||||
}
|
||||
|
||||
// Create connection pool with health checks enabled
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Create health monitor
|
||||
healthConfig := DefaultHealthMonitorConfig()
|
||||
healthMonitor := NewHealthMonitor(pool, healthConfig)
|
||||
|
||||
backend := &RedisBackend{
|
||||
config: config,
|
||||
pool: pool,
|
||||
healthMonitor: healthMonitor,
|
||||
}
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
// Start health monitoring
|
||||
healthMonitor.Start()
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Set stores a value in Redis with TTL
|
||||
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
// Execute with retry logic
|
||||
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
var err error
|
||||
|
||||
// Use PSETEX for millisecond precision, SETEX for second precision
|
||||
if ttl > 0 {
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs (millisecond precision)
|
||||
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs (second precision)
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
_, err = conn.Do("SET", prefixedKey, string(value))
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Get retrieves a value from Redis
|
||||
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
var resultValue []byte
|
||||
var resultTTL time.Duration
|
||||
var resultExists bool
|
||||
|
||||
// Execute with retry logic
|
||||
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
// Get value
|
||||
resp, err := conn.Do("GET", prefixedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
r.misses.Add(1)
|
||||
resultExists = false
|
||||
return nil // Not an error, key just doesn't exist
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get TTL
|
||||
ttlResp, err := conn.Do("TTL", prefixedKey)
|
||||
if err != nil {
|
||||
// If TTL fails, still return the value
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = 0
|
||||
resultExists = true
|
||||
return nil
|
||||
}
|
||||
|
||||
ttlSeconds, _ := RESPInt(ttlResp)
|
||||
var ttl time.Duration
|
||||
if ttlSeconds > 0 {
|
||||
ttl = time.Duration(ttlSeconds) * time.Second
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = ttl
|
||||
resultExists = true
|
||||
return nil
|
||||
})
|
||||
|
||||
return resultValue, resultTTL, resultExists, err
|
||||
}
|
||||
|
||||
// Delete removes a key from Redis
|
||||
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("DEL", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in Redis
|
||||
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("EXISTS", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys with the configured prefix
|
||||
func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
// Use FLUSHDB if no prefix (clear entire DB)
|
||||
if r.config.RedisPrefix == "" {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
}
|
||||
|
||||
// With prefix, we need to scan and delete keys
|
||||
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
|
||||
pattern := r.config.RedisPrefix + "*"
|
||||
resp, err := conn.Do("KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract keys from array response
|
||||
keys, ok := resp.([]interface{})
|
||||
if !ok || len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete each key
|
||||
for _, keyInterface := range keys {
|
||||
key, err := RESPString(keyInterface)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns backend statistics
|
||||
func (r *RedisBackend) GetStats() map[string]interface{} {
|
||||
hits := r.hits.Load()
|
||||
misses := r.misses.Load()
|
||||
total := hits + misses
|
||||
|
||||
hitRate := float64(0)
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"backend": "redis-pure-go",
|
||||
"address": r.config.RedisAddr,
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": hitRate,
|
||||
"pool": r.pool.Stats(),
|
||||
}
|
||||
|
||||
// Add health monitor stats if available
|
||||
if r.healthMonitor != nil {
|
||||
stats["health"] = r.healthMonitor.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks Redis connectivity
|
||||
func (r *RedisBackend) Ping(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
_, err = conn.Do("PING")
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the Redis backend and all connections
|
||||
func (r *RedisBackend) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Stop health monitor
|
||||
if r.healthMonitor != nil {
|
||||
r.healthMonitor.Stop()
|
||||
}
|
||||
|
||||
// Close connection pool
|
||||
if r.pool != nil {
|
||||
return r.pool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds the configured prefix to a key
|
||||
func (r *RedisBackend) prefixKey(key string) string {
|
||||
if r.config.RedisPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 100 * time.Millisecond
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
// If we can't get a connection and this is the last attempt, fail
|
||||
if attempt == maxRetries-1 {
|
||||
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the operation
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// If successful, return
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If error is not retryable or last attempt, fail
|
||||
if attempt == maxRetries-1 || !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts", maxRetries)
|
||||
}
|
||||
|
||||
// isRetryableError determines if an error is worth retrying
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retry on connection errors, timeouts, etc.
|
||||
// Don't retry on application-level errors like wrong type
|
||||
errMsg := err.Error()
|
||||
retryablePatterns := []string{
|
||||
"connection",
|
||||
"timeout",
|
||||
"EOF",
|
||||
"broken pipe",
|
||||
"reset by peer",
|
||||
}
|
||||
|
||||
for _, pattern := range retryablePatterns {
|
||||
if contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
|
||||
// State
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
lastCheckTime atomic.Int64 // Unix timestamp
|
||||
|
||||
// Metrics
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
CheckInterval time.Duration // How often to check health
|
||||
Timeout time.Duration // Timeout for health check
|
||||
UnhealthyThreshold int // Consecutive failures before marking unhealthy
|
||||
OnHealthChange func(healthy bool)
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
|
||||
return &HealthMonitorConfig{
|
||||
CheckInterval: 5 * time.Second,
|
||||
Timeout: 3 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor
|
||||
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
|
||||
if config == nil {
|
||||
config = DefaultHealthMonitorConfig()
|
||||
}
|
||||
|
||||
hm := &HealthMonitor{
|
||||
pool: pool,
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
hm.healthy.Store(true) // Assume healthy initially
|
||||
return hm
|
||||
}
|
||||
|
||||
// Start begins health monitoring
|
||||
func (hm *HealthMonitor) Start() {
|
||||
if hm.running.Swap(true) {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
hm.wg.Add(1)
|
||||
go hm.monitorLoop()
|
||||
}
|
||||
|
||||
// Stop stops health monitoring
|
||||
func (hm *HealthMonitor) Stop() {
|
||||
if !hm.running.Swap(false) {
|
||||
return // Not running
|
||||
}
|
||||
|
||||
close(hm.stopChan)
|
||||
hm.wg.Wait()
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (hm *HealthMonitor) IsHealthy() bool {
|
||||
return hm.healthy.Load()
|
||||
}
|
||||
|
||||
// GetStats returns health monitor statistics
|
||||
func (hm *HealthMonitor) GetStats() map[string]interface{} {
|
||||
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
|
||||
|
||||
return map[string]interface{}{
|
||||
"healthy": hm.healthy.Load(),
|
||||
"consecutive_failures": hm.consecutiveFailures.Load(),
|
||||
"total_checks": hm.totalChecks.Load(),
|
||||
"total_failures": hm.totalFailures.Load(),
|
||||
"last_check": lastCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop runs the health check loop
|
||||
func (hm *HealthMonitor) monitorLoop() {
|
||||
defer hm.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(hm.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Perform initial check immediately
|
||||
hm.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hm.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck executes a health check
|
||||
func (hm *HealthMonitor) performHealthCheck() {
|
||||
hm.totalChecks.Add(1)
|
||||
hm.lastCheckTime.Store(time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to get a connection and ping Redis
|
||||
conn, err := hm.pool.Get(ctx)
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
defer hm.pool.Put(conn)
|
||||
|
||||
// Ping Redis
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Success!
|
||||
hm.recordSuccess()
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hm *HealthMonitor) recordSuccess() {
|
||||
wasHealthy := hm.healthy.Load()
|
||||
hm.consecutiveFailures.Store(0)
|
||||
hm.healthy.Store(true)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if !wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(true)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hm *HealthMonitor) recordFailure() {
|
||||
hm.totalFailures.Add(1)
|
||||
failures := hm.consecutiveFailures.Add(1)
|
||||
|
||||
wasHealthy := hm.healthy.Load()
|
||||
|
||||
// Mark unhealthy if threshold exceeded
|
||||
if failures >= int64(hm.config.UnhealthyThreshold) {
|
||||
hm.healthy.Store(false)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHealthMonitor_BasicOperation tests basic health monitoring
|
||||
func TestHealthMonitor_BasicOperation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create health monitor with fast check interval for testing
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
require.NotNil(t, hm)
|
||||
|
||||
// Initially should be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for a few checks
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Should still be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
require.NotNil(t, stats)
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Greater(t, stats["total_checks"].(int64), int64(0))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
|
||||
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var healthChangedCalled atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if !healthy {
|
||||
healthChangedCalled.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
|
||||
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
assert.False(t, stats["healthy"].(bool))
|
||||
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
|
||||
assert.Greater(t, stats["total_failures"].(int64), int64(0))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
|
||||
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var recoveryDetected atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if healthy {
|
||||
recoveryDetected.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Should detect server failure")
|
||||
|
||||
// Clear error to simulate recovery
|
||||
mr.ClearError()
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should be healthy again
|
||||
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
|
||||
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
|
||||
|
||||
// Consecutive failures should be reset
|
||||
stats := hm.GetStats()
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StartStop tests start/stop behavior
|
||||
func TestHealthMonitor_StartStop(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Starting again should be no-op
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Stop monitoring
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
|
||||
// Stopping again should be no-op
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
|
||||
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create multiple monitors
|
||||
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 150 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
})
|
||||
|
||||
// Start both
|
||||
hm1.Start()
|
||||
hm2.Start()
|
||||
|
||||
// Both should be healthy
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
assert.True(t, hm1.IsHealthy())
|
||||
assert.True(t, hm2.IsHealthy())
|
||||
|
||||
// Stop both
|
||||
hm1.Stop()
|
||||
hm2.Stop()
|
||||
|
||||
// Verify they stopped
|
||||
assert.False(t, hm1.running.Load())
|
||||
assert.False(t, hm2.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StatsAccuracy tests stats tracking
|
||||
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for some checks
|
||||
time.Sleep(550 * time.Millisecond)
|
||||
|
||||
stats := hm.GetStats()
|
||||
|
||||
// Should have performed multiple checks
|
||||
totalChecks := stats["total_checks"].(int64)
|
||||
assert.GreaterOrEqual(t, totalChecks, int64(4))
|
||||
|
||||
// All checks should succeed
|
||||
assert.Equal(t, int64(0), stats["total_failures"].(int64))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
|
||||
// Last check time should be recent (within check interval + buffer)
|
||||
// Use 2s tolerance to account for CI runner load and timing variance
|
||||
lastCheck := stats["last_check"].(time.Time)
|
||||
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_DefaultConfig tests default configuration
|
||||
func TestHealthMonitor_DefaultConfig(t *testing.T) {
|
||||
config := DefaultHealthMonitorConfig()
|
||||
|
||||
assert.Equal(t, 5*time.Second, config.CheckInterval)
|
||||
assert.Equal(t, 3*time.Second, config.Timeout)
|
||||
assert.Equal(t, 3, config.UnhealthyThreshold)
|
||||
assert.Nil(t, config.OnHealthChange)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
|
||||
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1, // Very small pool
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 50 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Get the only connection, blocking health checks
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for health check attempts
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Health monitor might mark as unhealthy due to timeouts
|
||||
stats := hm.GetStats()
|
||||
t.Logf("Stats with blocked pool: %+v", stats)
|
||||
|
||||
// Return connection
|
||||
pool.Put(conn)
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Should recover
|
||||
assert.True(t, hm.IsHealthy())
|
||||
}
|
||||
|
||||
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
|
||||
func TestConnectionPool_WithHealthChecks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn))
|
||||
|
||||
// Use connection
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse and validate
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
|
||||
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 3,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get and return a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
|
||||
initialTotal := pool.totalConns.Load()
|
||||
|
||||
// Close the connection manually to make it stale
|
||||
conn.Close()
|
||||
|
||||
// Get another connection - should detect stale and create new
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn2))
|
||||
|
||||
pool.Put(conn2)
|
||||
|
||||
// Total connections might be same or less (stale removed)
|
||||
finalTotal := pool.totalConns.Load()
|
||||
assert.LessOrEqual(t, finalTotal, initialTotal+1)
|
||||
}
|
||||
+337
@@ -0,0 +1,337 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionPool manages a pool of Redis connections
|
||||
// Pure-Go implementation compatible with Yaegi
|
||||
type ConnectionPool struct {
|
||||
config *PoolConfig
|
||||
|
||||
connections chan *RedisConn
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Metrics
|
||||
activeConns atomic.Int32
|
||||
totalConns atomic.Int32
|
||||
gets atomic.Int64
|
||||
puts atomic.Int64
|
||||
timeouts atomic.Int64
|
||||
}
|
||||
|
||||
// PoolConfig holds connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if config.MaxConnections <= 0 {
|
||||
config.MaxConnections = 10
|
||||
}
|
||||
|
||||
if config.ConnectTimeout == 0 {
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
pool := &ConnectionPool{
|
||||
config: config,
|
||||
connections: make(chan *RedisConn, config.MaxConnections),
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool or creates a new one
|
||||
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.gets.Add(1)
|
||||
|
||||
// Try to get a connection with validation
|
||||
maxAttempts := 3
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
var conn *RedisConn
|
||||
var err error
|
||||
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
// Wait before retry with exponential backoff
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
p.totalConns.Add(1)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Pool exhausted, wait for a connection with timeout
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
p.timeouts.Add(1)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(p.config.ConnectTimeout):
|
||||
p.timeouts.Add(1)
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to get healthy connection after retries")
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.puts.Add(1)
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
|
||||
// Return to pool (non-blocking)
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *ConnectionPool) Close() error {
|
||||
if p.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
close(p.connections)
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns pool statistics
|
||||
func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"active_connections": p.activeConns.Load(),
|
||||
"total_connections": p.totalConns.Load(),
|
||||
"max_connections": p.config.MaxConnections,
|
||||
"gets": p.gets.Load(),
|
||||
"puts": p.puts.Load(),
|
||||
"timeouts": p.timeouts.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
redisConn := &RedisConn{
|
||||
conn: conn,
|
||||
readTimeout: p.config.ReadTimeout,
|
||||
writeTimeout: p.config.WriteTimeout,
|
||||
}
|
||||
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return redisConn, nil
|
||||
}
|
||||
|
||||
// RedisConn represents a single Redis connection
|
||||
type RedisConn struct {
|
||||
conn net.Conn
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Do executes a Redis command and returns the response
|
||||
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
if c.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
}
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
// Protect against possible overflow
|
||||
if len(s) > maxTotalArgBytes-totalBytes {
|
||||
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
|
||||
}
|
||||
totalBytes += len(s)
|
||||
if totalBytes > maxTotalArgBytes {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
writer := NewRESPWriter(c.conn)
|
||||
err := writer.WriteCommand(cmdArgs...)
|
||||
writer.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
c.closed.Store(true)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
reader := NewRESPReader(c.conn)
|
||||
resp, err := reader.ReadResponse()
|
||||
reader.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNilResponse) {
|
||||
c.closed.Store(true)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *RedisConn) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionHealthy validates a connection is still working
|
||||
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
if conn == nil || conn.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
+620
@@ -0,0 +1,620 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConnectionPool_BasicOperations tests basic pool operations
|
||||
func TestConnectionPool_BasicOperations(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
t.Run("GetAndPutConnection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Verify connection works
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse same connection
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
stats := pool.Stats()
|
||||
require.NotNil(t, stats)
|
||||
|
||||
assert.Contains(t, stats, "active_connections")
|
||||
assert.Contains(t, stats, "total_connections")
|
||||
assert.Contains(t, stats, "max_connections")
|
||||
assert.Equal(t, 5, stats["max_connections"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_MaxConnections tests pool size limits
|
||||
func TestConnectionPool_MaxConnections(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
maxConns := 3
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: maxConns,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get max connections
|
||||
conns := make([]*RedisConn, maxConns)
|
||||
for i := 0; i < maxConns; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
assert.Equal(t, int32(maxConns), stats["total_connections"])
|
||||
assert.Equal(t, int32(maxConns), stats["active_connections"])
|
||||
|
||||
// Try to get one more - should block/timeout
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Get(ctx2)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn)
|
||||
|
||||
// Return one connection
|
||||
pool.Put(conns[0])
|
||||
|
||||
// Now we should be able to get a connection
|
||||
conn, err = pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
for i := 1; i < maxConns; i++ {
|
||||
pool.Put(conns[i])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
|
||||
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
numGoroutines := 50
|
||||
numOperations := 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Spawn goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Do some work
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Small delay
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
t.Logf("Final stats: %+v", stats)
|
||||
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
|
||||
assert.Equal(t, int32(0), stats["active_connections"])
|
||||
}
|
||||
|
||||
// TestConnectionPool_ContextCancellation tests context cancellation
|
||||
func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get the only connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn2)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Authentication tests auth support
|
||||
func TestConnectionPool_Authentication(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
// Set password on miniredis
|
||||
mr.server.RequireAuth("secret-password")
|
||||
|
||||
t.Run("CorrectPassword", func(t *testing.T) {
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "secret-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
})
|
||||
|
||||
t.Run("WrongPassword", func(t *testing.T) {
|
||||
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "wrong-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
_, err := NewConnectionPool(config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_DatabaseSelection tests DB selection
|
||||
func TestConnectionPool_DatabaseSelection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
DB: 5,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connection should be on DB 5
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_ClosedConnection tests handling closed connections
|
||||
func TestConnectionPool_ClosedConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close it manually
|
||||
conn.Close()
|
||||
|
||||
// Try to use it
|
||||
_, err = conn.Do("PING")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Return to pool (should be discarded)
|
||||
pool.Put(conn)
|
||||
|
||||
// Get new connection - should create a new one
|
||||
conn2, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
resp, err := conn2.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Close tests pool closure
|
||||
func TestConnectionPool_Close(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get some connections
|
||||
conns := make([]*RedisConn, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Return them
|
||||
for _, conn := range conns {
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Close pool
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get connection from closed pool
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Close again should be no-op
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Timeouts tests various timeout scenarios
|
||||
func TestConnectionPool_Timeouts(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
WriteTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Normal operation should work
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestRedisConn_DoCommand tests the Do method
|
||||
func TestRedisConn_DoCommand(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SET and GET", func(t *testing.T) {
|
||||
// SET
|
||||
resp, err := conn.Do("SET", "testkey", "testvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// GET
|
||||
resp, err = conn.Do("GET", "testkey")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testvalue", resp)
|
||||
})
|
||||
|
||||
t.Run("DEL", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "delkey", "delvalue")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DEL
|
||||
resp, err := conn.Do("DEL", "delkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
|
||||
t.Run("EXISTS", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "existskey", "value")
|
||||
require.NoError(t, err)
|
||||
|
||||
// EXISTS - key exists
|
||||
resp, err := conn.Do("EXISTS", "existskey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// EXISTS - key doesn't exist
|
||||
resp, err = conn.Do("EXISTS", "nonexistent")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("TTL commands", func(t *testing.T) {
|
||||
// SETEX
|
||||
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// TTL
|
||||
resp, err = conn.Do("TTL", "ttlkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
ttl, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, ttl, int64(0))
|
||||
assert.LessOrEqual(t, ttl, int64(60))
|
||||
})
|
||||
}
|
||||
|
||||
// TestPoolConfig_Defaults tests default configuration values
|
||||
func TestPoolConfig_Defaults(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
// Leave other fields at zero values
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Should use defaults
|
||||
assert.Equal(t, 10, pool.config.MaxConnections)
|
||||
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
|
||||
|
||||
// Verify it works
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_NilConnection tests handling nil connections
|
||||
func TestConnectionPool_NilConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Putting nil should be safe
|
||||
pool.Put(nil)
|
||||
|
||||
// Pool should still work
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StatsTracking tests metrics tracking
|
||||
func TestConnectionPool_StatsTracking(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := pool.Stats()
|
||||
initialGets := stats["gets"].(int64)
|
||||
initialPuts := stats["puts"].(int64)
|
||||
|
||||
// Perform operations
|
||||
numOps := 10
|
||||
for i := 0; i < numOps; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Check updated stats
|
||||
stats = pool.Stats()
|
||||
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
|
||||
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
|
||||
assert.Equal(t, int32(0), stats["active_connections"].(int32))
|
||||
}
|
||||
|
||||
// TestRedisConn_TooManyArguments tests protection against allocation overflow
|
||||
func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("AcceptableArgumentCount", func(t *testing.T) {
|
||||
// Should work with reasonable number of args
|
||||
args := make([]string, 100)
|
||||
for i := range args {
|
||||
args[i] = "value"
|
||||
}
|
||||
_, err := conn.Do("MSET", args...)
|
||||
// May fail due to Redis constraints, but shouldn't panic or error on overflow
|
||||
// Just verify it doesn't trigger our overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RejectExcessiveArguments", func(t *testing.T) {
|
||||
// Create an absurdly large number of arguments that would cause overflow
|
||||
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
|
||||
args := make([]string, 1<<20) // 1,048,576 args
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("MSET", args...)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many arguments")
|
||||
})
|
||||
|
||||
t.Run("BoundaryCase", func(t *testing.T) {
|
||||
// Test exactly at the boundary (maxSafeArgs)
|
||||
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("ECHO", args...)
|
||||
// Should not error due to overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
}
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisBackend_BasicOperations tests basic Redis operations
|
||||
func TestRedisBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "redis-test-key"
|
||||
value := []byte("redis-test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "redis-delete-key"
|
||||
value := []byte("redis-delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "redis-exists-key"
|
||||
value := []byte("redis-exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
|
||||
func TestRedisBackend_KeyPrefixing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "test:prefix:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "my-key"
|
||||
value := []byte("my-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that key is stored with prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, "test:prefix:my-key", keys[0])
|
||||
|
||||
// Get should work without prefix
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// TestRedisBackend_TTLExpiration tests TTL handling
|
||||
func TestRedisBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward time in miniredis
|
||||
mr.FastForward(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLRemaining", func(t *testing.T) {
|
||||
key := "ttl-remaining-key"
|
||||
value := []byte("ttl-remaining-value")
|
||||
ttl := 10 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL is less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_Clear tests clearing all keys
|
||||
func TestRedisBackend_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "clear-test:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add multiple keys
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify keys exist
|
||||
keys := mr.CheckKeys()
|
||||
assert.Len(t, keys, 10)
|
||||
|
||||
// Clear all
|
||||
err = backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
keys = mr.CheckKeys()
|
||||
assert.Len(t, keys, 0)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
|
||||
func TestRedisBackend_ConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Try to connect to non-existent Redis
|
||||
config := DefaultRedisConfig("localhost:9999")
|
||||
_, err := NewRedisBackend(config)
|
||||
assert.Error(t, err, "Should fail to connect to non-existent Redis")
|
||||
}
|
||||
|
||||
// TestRedisBackend_RedisErrors tests handling of Redis errors
|
||||
func TestRedisBackend_RedisErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated error")
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Clear error
|
||||
mr.ClearError()
|
||||
|
||||
// Operations should work again
|
||||
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConcurrentAccess tests thread safety
|
||||
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0))
|
||||
}
|
||||
|
||||
// TestRedisBackend_Stats tests statistics tracking
|
||||
func TestRedisBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add and access items
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Get(ctx, "key1") // Hit
|
||||
backend.Get(ctx, "non-existent") // Miss
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Ping tests health check
|
||||
func TestRedisBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Close tests proper cleanup
|
||||
func TestRedisBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Double close should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_UpdateExisting tests updating existing keys
|
||||
func TestRedisBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute)
|
||||
}
|
||||
|
||||
// TestRedisBackend_LargeValues tests handling of large values
|
||||
func TestRedisBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_EmptyValues tests handling of empty values
|
||||
func TestRedisBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_PipelineOperations tests batch operations
|
||||
func TestRedisBackend_PipelineOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("batch-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("batch-value-%d", i))
|
||||
items[key] = value
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrieved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany", func(t *testing.T) {
|
||||
// Set test data
|
||||
testData := GenerateTestData(5)
|
||||
for key, value := range testData {
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Get all keys
|
||||
keys := make([]string, 0, len(testData))
|
||||
for key := range testData {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, len(testData))
|
||||
|
||||
for key, expectedValue := range testData {
|
||||
retrievedValue, exists := results[key]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyWithNonExistent", func(t *testing.T) {
|
||||
keys := []string{"exists-1", "non-existent", "exists-2"}
|
||||
|
||||
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
|
||||
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
assert.Equal(t, []byte("value-1"), results["exists-1"])
|
||||
assert.Equal(t, []byte("value-2"), results["exists-2"])
|
||||
_, exists := results["non-existent"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_NoPrefix tests operation without prefix
|
||||
func TestRedisBackend_NoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "" // No prefix
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "no-prefix-key"
|
||||
value := []byte("no-prefix-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check key is stored without prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, key, keys[0])
|
||||
}
|
||||
Vendored
+251
@@ -0,0 +1,251 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
|
||||
func (w *RESPWriter) WriteCommand(args ...string) error {
|
||||
// Write array header
|
||||
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write each argument as bulk string
|
||||
for _, arg := range args {
|
||||
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RESPReader reads RESP protocol messages
|
||||
type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
func (r *RESPReader) ReadResponse() (interface{}, error) {
|
||||
typeByte, err := r.r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typeByte {
|
||||
case '+': // Simple string
|
||||
return r.readSimpleString()
|
||||
case '-': // Error
|
||||
return nil, r.readError()
|
||||
case ':': // Integer
|
||||
return r.readInteger()
|
||||
case '$': // Bulk string
|
||||
return r.readBulkString()
|
||||
case '*': // Array
|
||||
return r.readArray()
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
|
||||
}
|
||||
}
|
||||
|
||||
// readSimpleString reads a simple string (+OK\r\n)
|
||||
func (r *RESPReader) readSimpleString() (string, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// readError reads an error message (-Error message\r\n)
|
||||
func (r *RESPReader) readError() error {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(line)
|
||||
}
|
||||
|
||||
// readInteger reads an integer (:1000\r\n)
|
||||
func (r *RESPReader) readInteger() (int64, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseInt(line, 10, 64)
|
||||
}
|
||||
|
||||
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
|
||||
func (r *RESPReader) readBulkString() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil bulk string
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read exactly 'length' bytes plus \r\n
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r.r, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify \r\n terminator
|
||||
if buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
|
||||
func (r *RESPReader) readArray() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil array
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read each element
|
||||
result := make([]interface{}, length)
|
||||
for i := 0; i < length; i++ {
|
||||
elem, err := r.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = elem
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readLine reads a line terminated by \r\n
|
||||
func (r *RESPReader) readLine() (string, error) {
|
||||
line, err := r.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Remove \r\n
|
||||
line = strings.TrimSuffix(line, "\r\n")
|
||||
if !strings.HasSuffix(line+"\r\n", "\r\n") {
|
||||
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// RESPString extracts a string from RESP response
|
||||
func RESPString(resp interface{}) (string, error) {
|
||||
if resp == nil {
|
||||
return "", ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("expected string, got %T", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// RESPInt extracts an integer from RESP response
|
||||
func RESPInt(resp interface{}) (int64, error) {
|
||||
if resp == nil {
|
||||
return 0, ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("expected integer, got %T", resp)
|
||||
}
|
||||
}
|
||||
Vendored
+495
@@ -0,0 +1,495 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRESPWriter_WriteCommand tests RESP command writing
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
args: []string{"PING"},
|
||||
expected: "*1\r\n$4\r\nPING\r\n",
|
||||
},
|
||||
{
|
||||
name: "SET command",
|
||||
args: []string{"SET", "key", "value"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "SETEX command",
|
||||
args: []string{"SETEX", "mykey", "60", "myvalue"},
|
||||
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "DEL with multiple keys",
|
||||
args: []string{"DEL", "key1", "key2", "key3"},
|
||||
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with empty string",
|
||||
args: []string{"SET", "key", ""},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with special characters",
|
||||
args: []string{"SET", "key", "val\r\nue"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(buf)
|
||||
|
||||
err := writer.WriteCommand(tt.args...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadSimpleString tests reading simple strings
|
||||
func TestRESPReader_ReadSimpleString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "OK response",
|
||||
input: "+OK\r\n",
|
||||
expected: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "PONG response",
|
||||
input: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "+\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "+Hello World\r\n",
|
||||
expected: "Hello World",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadError tests reading error messages
|
||||
func TestRESPReader_ReadError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ERR error",
|
||||
input: "-ERR unknown command\r\n",
|
||||
expectedError: "ERR unknown command",
|
||||
},
|
||||
{
|
||||
name: "WRONGTYPE error",
|
||||
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
|
||||
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
|
||||
},
|
||||
{
|
||||
name: "Simple error",
|
||||
input: "-Error\r\n",
|
||||
expectedError: "Error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadInteger tests reading integers
|
||||
func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Zero",
|
||||
input: ":0\r\n",
|
||||
expected: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Positive integer",
|
||||
input: ":1000\r\n",
|
||||
expected: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Negative integer",
|
||||
input: ":-1\r\n",
|
||||
expected: -1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Large integer",
|
||||
input: ":9223372036854775807\r\n",
|
||||
expected: 9223372036854775807,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid integer",
|
||||
input: ":abc\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Simple bulk string",
|
||||
input: "$6\r\nfoobar\r\n",
|
||||
expected: "foobar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty bulk string",
|
||||
input: "$0\r\n\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil bulk string",
|
||||
input: "$-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Binary safe bulk string",
|
||||
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
|
||||
expected: "\x00\x01\x02\x03\x04",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid length",
|
||||
input: "$abc\r\ntest\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadArray tests reading arrays
|
||||
func TestRESPReader_ReadArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Empty array",
|
||||
input: "*0\r\n",
|
||||
expected: []interface{}{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of bulk strings",
|
||||
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
|
||||
expected: []interface{}{
|
||||
"foo",
|
||||
"bar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of integers",
|
||||
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Mixed array",
|
||||
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
int64(4),
|
||||
"foobar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil array",
|
||||
input: "*-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Nested arrays",
|
||||
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
|
||||
expected: []interface{}{
|
||||
[]interface{}{"foo", "bar"},
|
||||
[]interface{}{"baz"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_InvalidInput tests error handling for invalid input
|
||||
func TestRESPReader_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Unknown type byte",
|
||||
input: "?invalid\r\n",
|
||||
},
|
||||
{
|
||||
name: "Incomplete response",
|
||||
input: "+OK",
|
||||
},
|
||||
{
|
||||
name: "Missing CRLF in bulk string",
|
||||
input: "$5\r\nhello",
|
||||
},
|
||||
{
|
||||
name: "Truncated array",
|
||||
input: "*3\r\n:1\r\n:2\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_EOF tests handling of EOF
|
||||
func TestRESPReader_EOF(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(""))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, io.EOF))
|
||||
}
|
||||
|
||||
// TestRESPHelpers tests helper functions
|
||||
func TestRESPHelpers(t *testing.T) {
|
||||
t.Run("RESPString", func(t *testing.T) {
|
||||
// Valid string
|
||||
result, err := RESPString("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", result)
|
||||
|
||||
// Byte slice
|
||||
result, err = RESPString([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "world", result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPString(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPString(123)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("RESPInt", func(t *testing.T) {
|
||||
// Valid int64
|
||||
result, err := RESPInt(int64(42))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Valid int
|
||||
result, err = RESPInt(42)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPInt(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPInt("string")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command []string
|
||||
response string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
command: []string{"PING"},
|
||||
response: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
},
|
||||
{
|
||||
name: "GET command with result",
|
||||
command: []string{"GET", "mykey"},
|
||||
response: "$7\r\nmyvalue\r\n",
|
||||
expected: "myvalue",
|
||||
},
|
||||
{
|
||||
name: "GET command with nil",
|
||||
command: []string{"GET", "nonexistent"},
|
||||
response: "$-1\r\n",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "DEL command",
|
||||
command: []string{"DEL", "key1", "key2"},
|
||||
response: ":2\r\n",
|
||||
expected: int64(2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Write command
|
||||
writeBuf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(writeBuf)
|
||||
err := writer.WriteCommand(tt.command...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
reader := NewRESPReader(strings.NewReader(tt.response))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.expected == nil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+198
@@ -0,0 +1,198 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLogger implements a simple logger for tests
|
||||
type TestLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewTestLogger(t *testing.T) *TestLogger {
|
||||
return &TestLogger{t: t}
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debug(format string, args ...interface{}) {
|
||||
l.t.Logf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Info(format string, args ...interface{}) {
|
||||
l.t.Logf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Error(format string, args ...interface{}) {
|
||||
l.t.Logf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Debug(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Infof(format string, args ...interface{}) {
|
||||
l.Info(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Error(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Warnf(format string, args ...interface{}) {
|
||||
l.t.Logf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
// MiniredisServer manages a miniredis instance for testing
|
||||
type MiniredisServer struct {
|
||||
server *miniredis.Miniredis
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// NewMiniredisServer creates a new miniredis server for testing
|
||||
func NewMiniredisServer(t *testing.T) *MiniredisServer {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err, "failed to start miniredis")
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// Verify connection
|
||||
ctx := context.Background()
|
||||
err = client.Ping(ctx).Err()
|
||||
require.NoError(t, err, "failed to ping miniredis")
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return &MiniredisServer{
|
||||
server: mr,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAddr returns the address of the miniredis server
|
||||
func (m *MiniredisServer) GetAddr() string {
|
||||
return m.server.Addr()
|
||||
}
|
||||
|
||||
// GetClient returns the Redis client
|
||||
func (m *MiniredisServer) GetClient() *redis.Client {
|
||||
return m.client
|
||||
}
|
||||
|
||||
// FastForward advances the miniredis server's time
|
||||
func (m *MiniredisServer) FastForward(d time.Duration) {
|
||||
m.server.FastForward(d)
|
||||
}
|
||||
|
||||
// FlushAll removes all keys from the database
|
||||
func (m *MiniredisServer) FlushAll() {
|
||||
m.server.FlushAll()
|
||||
}
|
||||
|
||||
// SetError simulates a Redis error
|
||||
func (m *MiniredisServer) SetError(err string) {
|
||||
m.server.SetError(err)
|
||||
}
|
||||
|
||||
// ClearError clears any simulated errors
|
||||
func (m *MiniredisServer) ClearError() {
|
||||
m.server.SetError("")
|
||||
}
|
||||
|
||||
// CheckKeys verifies that specific keys exist in Redis
|
||||
func (m *MiniredisServer) CheckKeys() []string {
|
||||
return m.server.Keys()
|
||||
}
|
||||
|
||||
// Close closes the miniredis server
|
||||
func (m *MiniredisServer) Close() {
|
||||
m.server.Close()
|
||||
}
|
||||
|
||||
// Restart restarts the miniredis server
|
||||
func (m *MiniredisServer) Restart() {
|
||||
m.server.Restart()
|
||||
}
|
||||
|
||||
// TestConfig provides default test configuration
|
||||
type TestConfig struct {
|
||||
MaxSize int
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultTestConfig returns a standard test configuration
|
||||
func DefaultTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
CleanupInterval: 1 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTestData creates test cache data
|
||||
func GenerateTestData(count int) map[string][]byte {
|
||||
data := make(map[string][]byte, count)
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("test-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("test-value-%d", i))
|
||||
data[key] = value
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// GenerateLargeValue creates a large test value
|
||||
func GenerateLargeValue(sizeBytes int) []byte {
|
||||
return make([]byte, sizeBytes)
|
||||
}
|
||||
|
||||
// AssertCacheStats is a helper to verify cache statistics
|
||||
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
|
||||
t.Helper()
|
||||
|
||||
hits, ok := stats["hits"].(int64)
|
||||
require.True(t, ok, "hits should be int64")
|
||||
require.Equal(t, expectedHits, hits, "unexpected hit count")
|
||||
|
||||
misses, ok := stats["misses"].(int64)
|
||||
require.True(t, ok, "misses should be int64")
|
||||
require.Equal(t, expectedMisses, misses, "unexpected miss count")
|
||||
}
|
||||
|
||||
// WaitForCondition waits for a condition to be true or times out
|
||||
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(checkInterval)
|
||||
}
|
||||
t.Fatal("timeout waiting for condition")
|
||||
}
|
||||
|
||||
// AssertEventuallyExpires verifies that a key eventually expires
|
||||
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
|
||||
_, _, exists, err := backend.Get(ctx, key)
|
||||
return err == nil && !exists
|
||||
})
|
||||
}
|
||||
Vendored
+96
-10
@@ -1880,19 +1880,20 @@ func TestConcurrentManagerOperations(t *testing.T) {
|
||||
// TestTTLExpirationAndCleanup tests TTL expiration and cleanup routines comprehensively
|
||||
func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 10 * time.Millisecond
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
config.EnableAutoCleanup = true
|
||||
cache := New(config)
|
||||
defer cache.Close()
|
||||
|
||||
// Test various TTL scenarios
|
||||
// Note: Timing increased 5x to account for race detector overhead
|
||||
testCases := []struct {
|
||||
key string
|
||||
ttl time.Duration
|
||||
}{
|
||||
{"very-short", 5 * time.Millisecond},
|
||||
{"short", 25 * time.Millisecond},
|
||||
{"medium", 100 * time.Millisecond},
|
||||
{"very-short", 25 * time.Millisecond},
|
||||
{"short", 125 * time.Millisecond},
|
||||
{"medium", 500 * time.Millisecond},
|
||||
{"long", 1 * time.Hour},
|
||||
}
|
||||
|
||||
@@ -1908,13 +1909,13 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for very short items to expire
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
if _, exists := cache.Get("very-short"); exists {
|
||||
t.Error("Very short item should be expired")
|
||||
}
|
||||
|
||||
// Wait for short items to expire
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
if _, exists := cache.Get("short"); exists {
|
||||
t.Error("Short item should be expired")
|
||||
}
|
||||
@@ -1930,16 +1931,16 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test manual cleanup
|
||||
cache.Set("manual-cleanup", "value", 1*time.Millisecond)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.Set("manual-cleanup", "value", 5*time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
cache.Cleanup()
|
||||
|
||||
// Add many expired items to test bulk cleanup
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bulk-%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 1*time.Millisecond)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 5*time.Millisecond)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
|
||||
sizeBefore := cache.Size()
|
||||
cache.Cleanup()
|
||||
@@ -2038,3 +2039,88 @@ func TestCacheStatisticsAndMetrics(t *testing.T) {
|
||||
t.Error("Memory usage should increase after adding large item")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// noOpLogger Tests
|
||||
// ============================================================================
|
||||
|
||||
// TestNoOpLogger_AllMethods tests all noOpLogger methods to ensure they don't panic
|
||||
func TestNoOpLogger_AllMethods(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Test simple message methods
|
||||
logger.Debug("test debug message")
|
||||
logger.Info("test info message")
|
||||
logger.Error("test error message")
|
||||
logger.Warn("test warn message")
|
||||
logger.Fatal("test fatal message")
|
||||
|
||||
// Test formatted message methods
|
||||
logger.Debugf("test debug: %s", "value")
|
||||
logger.Infof("test info: %s", "value")
|
||||
logger.Errorf("test error: %s", "value")
|
||||
logger.Warnf("test warn: %s", "value")
|
||||
logger.Fatalf("test fatal: %s", "value")
|
||||
|
||||
// If we reach here, all methods executed without panicking
|
||||
// This is expected behavior for a no-op logger
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithField verifies WithField returns the same logger
|
||||
func TestNoOpLogger_WithField(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
result := logger.WithField("key", "value")
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithField should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithField")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithFields verifies WithFields returns the same logger
|
||||
func TestNoOpLogger_WithFields(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
"key3": true,
|
||||
}
|
||||
|
||||
result := logger.WithFields(fields)
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithFields should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithFields")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_Chaining verifies method chaining works
|
||||
func TestNoOpLogger_Chaining(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Use WithField and verify it returns a usable logger
|
||||
result := logger.WithField("key1", "value1")
|
||||
|
||||
// Verify the result can be used for logging (Logger interface methods)
|
||||
result.Info("info after WithField")
|
||||
result.Infof("infof after WithField: %s", "test")
|
||||
result.Debug("debug after WithField")
|
||||
result.Debugf("debugf after WithField: %d", 123)
|
||||
result.Error("error after WithField")
|
||||
result.Errorf("errorf after WithField: %v", true)
|
||||
|
||||
// Use WithFields and verify it returns a usable logger
|
||||
result2 := logger.WithFields(map[string]interface{}{
|
||||
"key2": "value2",
|
||||
"key3": 123,
|
||||
})
|
||||
|
||||
// Verify the result can be used for logging
|
||||
result2.Infof("message after WithFields: %s", "test")
|
||||
}
|
||||
|
||||
+329
@@ -0,0 +1,329 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrCircuitOpen is returned when the circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTooManyRequests is returned when too many requests are made in half-open state
|
||||
ErrTooManyRequests = errors.New("too many requests in half-open state")
|
||||
)
|
||||
|
||||
// State represents the state of the circuit breaker
|
||||
type State int32
|
||||
|
||||
const (
|
||||
// StateClosed allows all operations to pass through
|
||||
StateClosed State = iota
|
||||
|
||||
// StateOpen blocks all operations
|
||||
StateOpen
|
||||
|
||||
// StateHalfOpen allows a limited number of operations to test recovery
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
// String returns the string representation of the state
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
return &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
FailureThreshold: 0.6,
|
||||
Timeout: 30 * time.Second,
|
||||
HalfOpenMaxRequests: 3,
|
||||
ResetTimeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
config: config,
|
||||
lastStateChange: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a function through the circuit breaker
|
||||
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
|
||||
if !cb.AllowRequest() {
|
||||
cb.rejectedRequests.Add(1)
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
cb.totalRequests.Add(1)
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.RecordFailure()
|
||||
} else {
|
||||
cb.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AllowRequest checks if a request is allowed to proceed
|
||||
func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// Check if timeout has passed and we should try half-open
|
||||
cb.timeMu.RLock()
|
||||
shouldRetry := time.Now().After(cb.nextRetryTime)
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
if shouldRetry {
|
||||
cb.setState(StateHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.timeMu.Lock()
|
||||
cb.lastSuccessTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Reset consecutive failures
|
||||
cb.consecutiveFailures.Store(0)
|
||||
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.totalFailures.Add(1)
|
||||
failures := cb.consecutiveFailures.Add(1)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
// Check failure rate
|
||||
total := cb.totalRequests.Load()
|
||||
failureCount := cb.totalFailures.Load()
|
||||
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
// Any failure in half-open state reopens the circuit
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
// openCircuit transitions to open state
|
||||
func (cb *CircuitBreaker) openCircuit() {
|
||||
cb.setState(StateOpen)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
|
||||
cb.timeMu.Unlock()
|
||||
}
|
||||
|
||||
// GetState returns the current state
|
||||
func (cb *CircuitBreaker) GetState() State {
|
||||
return State(cb.state.Load())
|
||||
}
|
||||
|
||||
// setState changes the circuit breaker state
|
||||
func (cb *CircuitBreaker) setState(newState State) {
|
||||
oldState := State(cb.state.Swap(int32(newState)))
|
||||
|
||||
if oldState != newState {
|
||||
cb.stateTransitions.Add(1)
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMu.Unlock()
|
||||
|
||||
if cb.config.OnStateChange != nil {
|
||||
cb.config.OnStateChange(oldState, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.totalRequests.Store(0)
|
||||
cb.totalFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
cb.rejectedRequests.Store(0)
|
||||
cb.stateTransitions.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = now
|
||||
cb.lastSuccessTime = now
|
||||
cb.nextRetryTime = now
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = now
|
||||
cb.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// Stats returns circuit breaker statistics
|
||||
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
cb.timeMu.RLock()
|
||||
lastFailure := cb.lastFailureTime
|
||||
lastSuccess := cb.lastSuccessTime
|
||||
nextRetry := cb.nextRetryTime
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
cb.stateMu.RLock()
|
||||
lastChange := cb.lastStateChange
|
||||
cb.stateMu.RUnlock()
|
||||
|
||||
totalReq := cb.totalRequests.Load()
|
||||
totalFail := cb.totalFailures.Load()
|
||||
successRate := float64(0)
|
||||
if totalReq > 0 {
|
||||
successRate = float64(totalReq-totalFail) / float64(totalReq)
|
||||
}
|
||||
|
||||
return CircuitBreakerStats{
|
||||
State: cb.GetState(),
|
||||
ConsecutiveFailures: cb.consecutiveFailures.Load(),
|
||||
TotalRequests: totalReq,
|
||||
TotalFailures: totalFail,
|
||||
SuccessRate: successRate,
|
||||
RejectedRequests: cb.rejectedRequests.Load(),
|
||||
StateTransitions: cb.stateTransitions.Load(),
|
||||
LastFailureTime: lastFailure,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastStateChange: lastChange,
|
||||
NextRetryTime: nextRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
func (cb *CircuitBreaker) IsHealthy() bool {
|
||||
return cb.GetState() != StateOpen
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
|
||||
type CircuitBreakerBackend struct {
|
||||
backend backends.CacheBackend
|
||||
cb *CircuitBreaker
|
||||
}
|
||||
|
||||
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
|
||||
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreakerBackend{
|
||||
backend: b,
|
||||
cb: NewCircuitBreaker(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Set(ctx, key, value, ttl)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return nil, 0, false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
value, ttl, exists, err := c.backend.Get(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
deleted, err := c.backend.Delete(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
exists, err := c.backend.Exists(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Clear(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including circuit breaker state
|
||||
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
|
||||
stats := c.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
cbStats := c.cb.Stats()
|
||||
stats["circuit_breaker"] = map[string]interface{}{
|
||||
"state": cbStats.State.String(),
|
||||
"consecutive_failures": cbStats.ConsecutiveFailures,
|
||||
"total_requests": cbStats.TotalRequests,
|
||||
"total_failures": cbStats.TotalFailures,
|
||||
"success_rate": cbStats.SuccessRate,
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Ping(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the backend
|
||||
func (c *CircuitBreakerBackend) Close() error {
|
||||
return c.backend.Close()
|
||||
}
|
||||
@@ -0,0 +1,561 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockBackend is a simple mock implementation for testing
|
||||
type mockBackend struct {
|
||||
data map[string]mockEntry
|
||||
mu sync.RWMutex
|
||||
failSet bool
|
||||
failGet bool
|
||||
failDelete bool
|
||||
failExists bool
|
||||
failClear bool
|
||||
failPing bool
|
||||
callCount int
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockBackend() *mockBackend {
|
||||
return &mockBackend{
|
||||
data: make(map[string]mockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failSet {
|
||||
return errors.New("mock set error")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
if ttl == 0 {
|
||||
expiresAt = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
m.data[key] = mockEntry{
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failGet {
|
||||
return nil, 0, false, errors.New("mock get error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
ttl := time.Until(entry.expiresAt)
|
||||
return entry.value, ttl, true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failDelete {
|
||||
return false, errors.New("mock delete error")
|
||||
}
|
||||
|
||||
_, existed := m.data[key]
|
||||
delete(m.data, key)
|
||||
return existed, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failExists {
|
||||
return false, errors.New("mock exists error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Clear(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failClear {
|
||||
return errors.New("mock clear error")
|
||||
}
|
||||
|
||||
m.data = make(map[string]mockEntry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) GetStats() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"hits": int64(0),
|
||||
"misses": int64(0),
|
||||
"call_count": m.callCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Ping(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failPing {
|
||||
return errors.New("mock ping error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constructor Tests
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithDefaultConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
require.NotNil(t, cb)
|
||||
|
||||
// Verify it implements the interface (compile-time check)
|
||||
var _ backends.CacheBackend = cb
|
||||
}
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithCustomConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
FailureThreshold: 0.5,
|
||||
Timeout: 5 * time.Second,
|
||||
HalfOpenMaxRequests: 2,
|
||||
ResetTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
require.NotNil(t, cb)
|
||||
}
|
||||
|
||||
// Set Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, mockBE.callCount)
|
||||
|
||||
// Verify value was stored
|
||||
value, _, exists, _ := mockBE.Get(ctx, "key1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Circuit should be open now
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Get Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First set a value
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Now get it through circuit breaker
|
||||
value, _, exists, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _, _, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Get(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, _, _, err := cb.Get(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Delete Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Delete through circuit breaker
|
||||
deleted, err := cb.Delete(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
// Verify it's deleted
|
||||
exists, _ := mockBE.Exists(ctx, "key1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failDelete = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Delete(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Delete(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Exists Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Check existence through circuit breaker
|
||||
exists, err := cb.Exists(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failExists = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Exists(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Exists(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Clear Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some values
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
mockBE.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Clear through circuit breaker
|
||||
err := cb.Clear(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify cleared
|
||||
exists1, _ := mockBE.Exists(ctx, "key1")
|
||||
exists2, _ := mockBE.Exists(ctx, "key2")
|
||||
assert.False(t, exists1)
|
||||
assert.False(t, exists2)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failClear = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Clear(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Clear(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// GetStats Tests
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Perform some operations
|
||||
cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
cb.Get(ctx, "key1")
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
|
||||
// Should have circuit breaker stats
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
|
||||
cbStats, ok := stats["circuit_breaker"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// Verify circuit breaker stats fields
|
||||
assert.Contains(t, cbStats, "state")
|
||||
assert.Contains(t, cbStats, "consecutive_failures")
|
||||
assert.Contains(t, cbStats, "total_requests")
|
||||
assert.Contains(t, cbStats, "total_failures")
|
||||
assert.Contains(t, cbStats, "success_rate")
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats_NilBackendStats(t *testing.T) {
|
||||
// Create a mock backend that returns nil stats
|
||||
mockBE := &mockBackendNilStats{}
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
}
|
||||
|
||||
// mockBackendNilStats returns nil from GetStats
|
||||
type mockBackendNilStats struct {
|
||||
mockBackend
|
||||
}
|
||||
|
||||
func (m *mockBackendNilStats) GetStats() map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Ping(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failPing = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Ping(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Close Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Close(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
err := cb.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Circuit Recovery Test
|
||||
|
||||
func TestCircuitBreakerBackend_CircuitRecovery(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Fix the backend
|
||||
mockBE.mu.Lock()
|
||||
mockBE.failSet = false
|
||||
mockBE.mu.Unlock()
|
||||
|
||||
// Circuit should be in half-open state, allow a test request
|
||||
err = cb.Set(ctx, "key3", []byte("value3"), 1*time.Minute)
|
||||
|
||||
// After success threshold is met, circuit should close
|
||||
if err == nil {
|
||||
// Circuit recovered
|
||||
err2 := cb.Set(ctx, "key4", []byte("value4"), 1*time.Minute)
|
||||
assert.NoError(t, err2, "Circuit should be closed after recovery")
|
||||
}
|
||||
}
|
||||
+553
@@ -0,0 +1,553 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreaker_StateTransitions tests state machine transitions
|
||||
func TestCircuitBreaker_StateTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Initial state is closed", func(t *testing.T) {
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Closed to Open after max failures", func(t *testing.T) {
|
||||
cb.Reset()
|
||||
|
||||
// Simulate failures
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
|
||||
// Open the circuit
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should allow request and transition to half-open
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First request transitions to half-open and succeeds
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should be in half-open after first request
|
||||
state := cb.GetState()
|
||||
assert.True(t, state == StateHalfOpen || state == StateClosed,
|
||||
"After first successful request, should be half-open or potentially closed")
|
||||
|
||||
if state == StateHalfOpen {
|
||||
// Need more successful requests to close
|
||||
// The exact number depends on implementation but should be within HalfOpenMaxRequests
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// After multiple successful requests, should eventually close
|
||||
finalState := cb.GetState()
|
||||
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
|
||||
"After successful requests, circuit should transition towards closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First call transitions to half-open, second failure reopens
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
|
||||
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Requests should be blocked
|
||||
err := cb.Execute(ctx, func() error {
|
||||
t.Fatal("Should not execute function when circuit is open")
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
|
||||
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit then wait for half-open
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// After timeout, circuit should allow transition to half-open
|
||||
// Execute HalfOpenMaxRequests successful requests
|
||||
successCount := 0
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
err := cb.Execute(ctx, func() error {
|
||||
successCount++
|
||||
return nil
|
||||
})
|
||||
// Should allow up to HalfOpenMaxRequests
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify we executed the expected number
|
||||
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
|
||||
|
||||
// After successful requests, circuit behavior depends on implementation
|
||||
// It could close (allowing more requests) or stay half-open (blocking)
|
||||
// The important thing is that we allowed exactly HalfOpenMaxRequests
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Have some failures (but less than max)
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
|
||||
// One success should reset failures
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats = cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentAccess tests thread safety
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 5,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of successes and failures
|
||||
cb.Execute(ctx, func() error {
|
||||
if (id+j)%3 == 0 {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Random state checks
|
||||
_ = cb.GetState()
|
||||
_ = cb.Stats()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without panics
|
||||
stats := cb.Stats()
|
||||
assert.NotNil(t, stats)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Stats tests statistics tracking
|
||||
func TestCircuitBreaker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Execute some requests
|
||||
cb.Execute(ctx, func() error { return nil }) // Success
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
|
||||
stats := cb.Stats()
|
||||
|
||||
assert.Equal(t, StateClosed, stats.State)
|
||||
assert.Equal(t, int64(3), stats.TotalRequests)
|
||||
assert.Equal(t, int64(2), stats.TotalFailures)
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests circuit reset
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open the circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
assert.Equal(t, int64(0), stats.TotalFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateChangeCallback tests state change notifications
|
||||
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 50 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
OnStateChange: func(from, to State) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger state transitions
|
||||
// Closed -> Open
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
// Should be open now
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Open -> HalfOpen on first request after timeout
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Execute more successful requests to trigger HalfOpen -> Closed
|
||||
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.Contains(t, transitions, "closed->open")
|
||||
assert.Contains(t, transitions, "open->half-open")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsHealthy tests health check
|
||||
func TestCircuitBreaker_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, cb.IsHealthy())
|
||||
|
||||
// Open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
|
||||
|
||||
// Wait for timeout and allow successful request
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Should be healthy after recovery
|
||||
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
|
||||
func TestCircuitBreaker_RapidFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Rapid failures
|
||||
for i := 0; i < 10; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("rapid error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
stats := cb.Stats()
|
||||
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
|
||||
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
timeout := 100 * time.Millisecond
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: timeout,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait just before timeout
|
||||
time.Sleep(timeout - 20*time.Millisecond)
|
||||
assert.False(t, cb.IsHealthy())
|
||||
|
||||
// Wait until after timeout
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
// After timeout, AllowRequest should return true for transition to half-open
|
||||
assert.True(t, cb.AllowRequest())
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_DefaultConfig tests default configuration
|
||||
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := NewCircuitBreaker(nil) // Should use defaults
|
||||
|
||||
assert.NotNil(t, cb)
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
|
||||
// Verify defaults by triggering circuit breaker behavior
|
||||
ctx := context.Background()
|
||||
|
||||
// Test that it takes 5 failures to open (default MaxFailures)
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
|
||||
|
||||
// 5th failure should open it
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateString tests state string representation
|
||||
func TestCircuitBreaker_StateString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "closed", StateClosed.String())
|
||||
assert.Equal(t, "open", StateOpen.String())
|
||||
assert.Equal(t, "half-open", StateHalfOpen.String())
|
||||
assert.Equal(t, "unknown", State(999).String())
|
||||
}
|
||||
|
||||
// Benchmark circuit breaker performance
|
||||
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 100,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1000,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
if i%10 == 0 {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
+375
@@ -0,0 +1,375 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a backend
|
||||
type HealthStatus int32
|
||||
|
||||
const (
|
||||
// HealthUnknown indicates unknown health status
|
||||
HealthUnknown HealthStatus = iota
|
||||
|
||||
// HealthHealthy indicates the backend is healthy
|
||||
HealthHealthy
|
||||
|
||||
// HealthDegraded indicates the backend is degraded but operational
|
||||
HealthDegraded
|
||||
|
||||
// HealthUnhealthy indicates the backend is unhealthy
|
||||
HealthUnhealthy
|
||||
)
|
||||
|
||||
// String returns the string representation of the health status
|
||||
func (h HealthStatus) String() string {
|
||||
switch h {
|
||||
case HealthHealthy:
|
||||
return "healthy"
|
||||
case HealthDegraded:
|
||||
return "degraded"
|
||||
case HealthUnhealthy:
|
||||
return "unhealthy"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
return &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
HealthyThreshold: 3,
|
||||
UnhealthyThreshold: 3,
|
||||
DegradedThreshold: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
hc := &HealthChecker{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Start begins health checking
|
||||
func (hc *HealthChecker) Start() {
|
||||
if hc.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
hc.ticker = time.NewTicker(hc.config.CheckInterval)
|
||||
hc.wg.Add(1)
|
||||
go hc.checkLoop()
|
||||
}
|
||||
|
||||
// Stop stops health checking
|
||||
func (hc *HealthChecker) Stop() {
|
||||
if hc.stopped.Swap(true) {
|
||||
return // Already stopped
|
||||
}
|
||||
|
||||
close(hc.stopChan)
|
||||
if hc.ticker != nil {
|
||||
hc.ticker.Stop()
|
||||
}
|
||||
hc.wg.Wait()
|
||||
}
|
||||
|
||||
// checkLoop runs periodic health checks
|
||||
func (hc *HealthChecker) checkLoop() {
|
||||
defer hc.wg.Done()
|
||||
|
||||
// Initial check - log error but continue
|
||||
if err := hc.Check(context.Background()); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
case <-hc.ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
|
||||
if err := hc.Check(ctx); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check performs a health check
|
||||
func (hc *HealthChecker) Check(ctx context.Context) error {
|
||||
if hc.config.CheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hc.totalChecks.Add(1)
|
||||
start := time.Now()
|
||||
|
||||
// Create timeout context if not already set
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Perform health check
|
||||
err := hc.config.CheckFunc(ctx)
|
||||
latency := time.Since(start)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Update average latency
|
||||
hc.updateAverageLatency(latency)
|
||||
|
||||
if err != nil {
|
||||
hc.recordFailure()
|
||||
} else {
|
||||
hc.recordSuccess(latency)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
hc.totalSuccesses.Add(1)
|
||||
successes := hc.consecutiveSuccesses.Add(1)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastSuccessTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
currentStatus := hc.GetStatus()
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
} else {
|
||||
newStatus = HealthHealthy
|
||||
}
|
||||
}
|
||||
|
||||
if newStatus != currentStatus {
|
||||
hc.setStatus(newStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hc *HealthChecker) recordFailure() {
|
||||
hc.totalFailures.Add(1)
|
||||
failures := hc.consecutiveFailures.Add(1)
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastFailureTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
}
|
||||
|
||||
// updateAverageLatency updates the rolling average latency
|
||||
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
|
||||
// Simple exponential moving average
|
||||
currentAvg := time.Duration(hc.averageLatency.Load())
|
||||
if currentAvg == 0 {
|
||||
hc.averageLatency.Store(int64(latency))
|
||||
} else {
|
||||
// Weight: 0.2 for new value, 0.8 for old average
|
||||
newAvg := (currentAvg*4 + latency) / 5
|
||||
hc.averageLatency.Store(int64(newAvg))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the current health status
|
||||
func (hc *HealthChecker) GetStatus() HealthStatus {
|
||||
return HealthStatus(hc.status.Load())
|
||||
}
|
||||
|
||||
// setStatus changes the health status
|
||||
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
|
||||
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
|
||||
|
||||
if oldStatus != newStatus {
|
||||
hc.statusChanges.Add(1)
|
||||
if hc.config.OnStatusChange != nil {
|
||||
hc.config.OnStatusChange(oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy or degraded
|
||||
func (hc *HealthChecker) IsHealthy() bool {
|
||||
status := hc.GetStatus()
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// LastCheckTime returns the time of the last health check
|
||||
func (hc *HealthChecker) LastCheckTime() time.Time {
|
||||
hc.timeMu.RLock()
|
||||
defer hc.timeMu.RUnlock()
|
||||
return hc.lastCheckTime
|
||||
}
|
||||
|
||||
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
|
||||
func (hc *HealthChecker) HealthScore() float64 {
|
||||
status := hc.GetStatus()
|
||||
switch status {
|
||||
case HealthHealthy:
|
||||
return 1.0
|
||||
case HealthDegraded:
|
||||
return 0.7
|
||||
case HealthUnhealthy:
|
||||
return 0.0
|
||||
default:
|
||||
return 0.5
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns health checker statistics
|
||||
func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
hc.timeMu.RLock()
|
||||
lastCheck := hc.lastCheckTime
|
||||
lastSuccess := hc.lastSuccessTime
|
||||
lastFailure := hc.lastFailureTime
|
||||
hc.timeMu.RUnlock()
|
||||
|
||||
totalChecks := hc.totalChecks.Load()
|
||||
totalSuccesses := hc.totalSuccesses.Load()
|
||||
totalFailures := hc.totalFailures.Load()
|
||||
|
||||
successRate := float64(0)
|
||||
if totalChecks > 0 {
|
||||
successRate = float64(totalSuccesses) / float64(totalChecks)
|
||||
}
|
||||
|
||||
return HealthCheckerStats{
|
||||
Status: hc.GetStatus(),
|
||||
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
|
||||
ConsecutiveFailures: hc.consecutiveFailures.Load(),
|
||||
TotalChecks: totalChecks,
|
||||
TotalSuccesses: totalSuccesses,
|
||||
TotalFailures: totalFailures,
|
||||
SuccessRate: successRate,
|
||||
AverageLatency: time.Duration(hc.averageLatency.Load()),
|
||||
StatusChanges: hc.statusChanges.Load(),
|
||||
LastCheckTime: lastCheck,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastFailureTime: lastFailure,
|
||||
HealthScore: hc.HealthScore(),
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
func (hc *HealthChecker) Reset() {
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
hc.totalChecks.Store(0)
|
||||
hc.totalSuccesses.Store(0)
|
||||
hc.totalFailures.Store(0)
|
||||
hc.statusChanges.Store(0)
|
||||
hc.averageLatency.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = now
|
||||
hc.lastSuccessTime = now
|
||||
hc.lastFailureTime = now
|
||||
hc.timeMu.Unlock()
|
||||
}
|
||||
+215
@@ -0,0 +1,215 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// HealthCheckBackend wraps a cache backend with health checking
|
||||
type HealthCheckBackend struct {
|
||||
backend backends.CacheBackend
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Health tracking
|
||||
status atomic.Int32
|
||||
consecutiveFails atomic.Int32
|
||||
consecutiveOK atomic.Int32
|
||||
lastCheck time.Time
|
||||
checkMutex sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthCheckBackend creates a new health check wrapped backend
|
||||
func NewHealthCheckBackend(b backends.CacheBackend, config *HealthCheckConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
hc := &HealthCheckBackend{
|
||||
backend: b,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set initial status to healthy (optimistic)
|
||||
hc.status.Store(int32(HealthHealthy))
|
||||
|
||||
// Start health check routine
|
||||
hc.wg.Add(1)
|
||||
go hc.healthCheckLoop()
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Set stores a value and tracks health
|
||||
func (h *HealthCheckBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Allow operations even if unhealthy (may recover)
|
||||
err := h.backend.Set(ctx, key, value, ttl)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value and tracks health
|
||||
func (h *HealthCheckBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
value, ttl, exists, err := h.backend.Get(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key and tracks health
|
||||
func (h *HealthCheckBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
deleted, err := h.backend.Delete(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists and tracks health
|
||||
func (h *HealthCheckBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
exists, err := h.backend.Exists(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys and tracks health
|
||||
func (h *HealthCheckBackend) Clear(ctx context.Context) error {
|
||||
err := h.backend.Clear(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including health status
|
||||
func (h *HealthCheckBackend) GetStats() map[string]interface{} {
|
||||
stats := h.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
h.checkMutex.RLock()
|
||||
lastCheck := h.lastCheck
|
||||
h.checkMutex.RUnlock()
|
||||
|
||||
status := HealthStatus(h.status.Load())
|
||||
stats["health"] = map[string]interface{}{
|
||||
"status": status.String(),
|
||||
"consecutive_fails": h.consecutiveFails.Load(),
|
||||
"consecutive_ok": h.consecutiveOK.Load(),
|
||||
"last_check": lastCheck.Format(time.RFC3339),
|
||||
"time_since_check": time.Since(lastCheck).Seconds(),
|
||||
"check_interval_sec": h.config.CheckInterval.Seconds(),
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health
|
||||
func (h *HealthCheckBackend) Ping(ctx context.Context) error {
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the health checker and backend
|
||||
func (h *HealthCheckBackend) Close() error {
|
||||
// Stop health check routine
|
||||
h.cancel()
|
||||
|
||||
// Wait for routine to finish
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Finished normally
|
||||
case <-time.After(2 * time.Second):
|
||||
// Timeout
|
||||
}
|
||||
|
||||
return h.backend.Close()
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
status := HealthStatus(h.status.Load())
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
// Check if we should transition to healthy
|
||||
if fails > 0 && oks >= int32(h.config.HealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthHealthy)))
|
||||
if oldStatus != HealthHealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthHealthy)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
oks := h.consecutiveOK.Swap(0)
|
||||
fails := h.consecutiveFails.Add(1)
|
||||
|
||||
// Check if we should transition to unhealthy
|
||||
if oks > 0 && fails >= int32(h.config.UnhealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthUnhealthy)))
|
||||
if oldStatus != HealthUnhealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthUnhealthy)
|
||||
}
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold)*2 {
|
||||
// Severely degraded
|
||||
h.status.Store(int32(HealthUnhealthy))
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold) {
|
||||
// Degraded but still trying
|
||||
h.status.Store(int32(HealthDegraded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheckLoop runs periodic health checks
|
||||
func (h *HealthCheckBackend) healthCheckLoop() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(h.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check
|
||||
h.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a single health check
|
||||
func (h *HealthCheckBackend) performHealthCheck() {
|
||||
h.checkMutex.Lock()
|
||||
h.lastCheck = time.Now()
|
||||
h.checkMutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
}
|
||||
+447
@@ -0,0 +1,447 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestHealthChecker_StatusTransitions tests health status transitions
|
||||
func TestHealthChecker_StatusTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Initially unknown
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy after threshold failures
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status == HealthUnhealthy || status == HealthDegraded)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should recover towards healthy
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded || finalStatus == HealthUnknown)
|
||||
}
|
||||
|
||||
// TestHealthChecker_InitialState tests initial health status
|
||||
func TestHealthChecker_InitialState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
assert.False(t, hc.IsHealthy())
|
||||
}
|
||||
|
||||
// TestHealthChecker_ForceCheck tests manual health check trigger
|
||||
func TestHealthChecker_ForceCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Second, // Long interval
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
initialCount := callCount.Load()
|
||||
|
||||
// Force check
|
||||
hc.Check(context.Background())
|
||||
|
||||
// Should have been called
|
||||
assert.Greater(t, callCount.Load(), initialCount)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusChangeCallback tests status change notifications
|
||||
func TestHealthChecker_StatusChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should have status transitions
|
||||
assert.NotEmpty(t, transitions)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Stats tests statistics tracking
|
||||
func TestHealthChecker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if callCount.Load()%2 == 0 {
|
||||
return errors.New("failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 5,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
stats := hc.Stats()
|
||||
|
||||
assert.Greater(t, stats.TotalChecks, int64(0))
|
||||
assert.Greater(t, stats.TotalFailures, int64(0))
|
||||
assert.Greater(t, stats.SuccessRate, 0.0)
|
||||
assert.Less(t, stats.SuccessRate, 1.0)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Timeout tests check timeout handling
|
||||
func TestHealthChecker_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
// Simulate slow check
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy due to timeouts
|
||||
status := hc.GetStatus()
|
||||
assert.NotEqual(t, HealthHealthy, status)
|
||||
}
|
||||
|
||||
// TestHealthChecker_ConcurrentAccess tests thread safety
|
||||
func TestHealthChecker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 50; j++ {
|
||||
_ = hc.GetStatus()
|
||||
_ = hc.IsHealthy()
|
||||
_ = hc.Stats()
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// Should complete without panics
|
||||
}
|
||||
|
||||
// TestHealthChecker_StopAndStart tests lifecycle management
|
||||
func TestHealthChecker_StopAndStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
// Start
|
||||
hc.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count1 := callCount.Load()
|
||||
assert.Greater(t, count1, int32(0))
|
||||
|
||||
// Stop
|
||||
hc.Stop()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count2 := callCount.Load()
|
||||
|
||||
// Should not have increased significantly after stop
|
||||
assert.Less(t, count2-count1, int32(3))
|
||||
}
|
||||
|
||||
// TestHealthChecker_DegradedState tests degraded status
|
||||
func TestHealthChecker_DegradedState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
count := callCount.Add(1)
|
||||
// Fail once, then succeed
|
||||
if count == 1 {
|
||||
return errors.New("single failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3, // Need 3 failures for unhealthy
|
||||
HealthyThreshold: 2, // Need 2 successes for healthy
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// After initial checks, status should be set (might be healthy or degraded based on execution)
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status != HealthUnknown, "Status should not be unknown after checks")
|
||||
}
|
||||
|
||||
// TestHealthChecker_DefaultConfig tests default configuration
|
||||
func TestHealthChecker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
assert.NotNil(t, hc)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Verify default config was applied (we can't access private fields, so just check it works)
|
||||
assert.NotNil(t, hc)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusString tests status string representation
|
||||
func TestHealthChecker_StatusString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "healthy", HealthHealthy.String())
|
||||
assert.Equal(t, "unhealthy", HealthUnhealthy.String())
|
||||
assert.Equal(t, "degraded", HealthDegraded.String())
|
||||
assert.Equal(t, "unknown", HealthStatus(999).String())
|
||||
}
|
||||
|
||||
// TestHealthChecker_RecoveryPattern tests typical failure and recovery
|
||||
func TestHealthChecker_RecoveryPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var checkNumber atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
n := checkNumber.Add(1)
|
||||
// Fail checks 3-5, succeed others
|
||||
if n >= 3 && n <= 5 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var statusLog []HealthStatus
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
statusLog = append(statusLog, to)
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should see transitions through unhealthy and back to healthy
|
||||
assert.NotEmpty(t, statusLog)
|
||||
|
||||
// Final status should be healthy or degraded (recovered)
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded, "Should have recovered")
|
||||
}
|
||||
|
||||
// Benchmark health checker performance
|
||||
func BenchmarkHealthChecker_ForceCheck(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Minute,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHealthChecker_Status(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = hc.GetStatus()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user