Add redis support for distributed caching

This commit is contained in:
2025-10-19 01:05:41 +01:00
parent e45b06c86d
commit 92a058da83
238 changed files with 78024 additions and 124 deletions
+295
View File
@@ -31,6 +31,7 @@ summary: |
- Flexible configuration with multiple deployment scenarios
- Memory-efficient operation with automatic cleanup
- Extensive logging and debugging capabilities
- Redis cache support for multi-replica deployments with automatic failover
It supports various authentication scenarios including:
- Basic authentication with customizable callback and logout URLs
@@ -137,6 +138,42 @@ testData:
X-Custom-Header: "production"
X-API-Version: "v1"
# Example with Redis cache for multi-replica deployments
testDataWithRedis:
# Required OIDC parameters (same as standard configuration)
providerURL: https://auth.example.com
clientID: your-client-id
clientSecret: your-client-secret
callbackURL: /oauth2/callback
sessionEncryptionKey: your-64-character-encryption-key-at-least-32-bytes
# Standard optional parameters
logLevel: info
allowedUserDomains:
- company.com
# Redis cache configuration for multi-replica support
redis:
enabled: true # Enable Redis caching
address: "redis:6379" # Redis server address
password: "redis-password" # Redis authentication password
db: 0 # Redis database number (0-15)
keyPrefix: "traefikoidc:" # Prefix for all Redis keys
cacheMode: "hybrid" # Cache mode: redis, hybrid, or memory
poolSize: 20 # Maximum number of connections
connectTimeout: 5 # Connection timeout in seconds
readTimeout: 3 # Read operation timeout
writeTimeout: 3 # Write operation timeout
enableTLS: false # Use TLS for Redis connection
tlsSkipVerify: false # Skip TLS certificate verification
hybridL1Size: 500 # L1 cache size for hybrid mode
hybridL1MemoryMB: 10 # L1 memory limit for hybrid mode
enableCircuitBreaker: true # Enable circuit breaker
circuitBreakerThreshold: 5 # Failures before opening circuit
circuitBreakerTimeout: 60 # Timeout before retry (seconds)
enableHealthCheck: true # Enable periodic health checks
healthCheckInterval: 30 # Health check interval (seconds)
# --- Common Configuration Examples ---
#
# 🔒 HIGH-SECURITY CONFIGURATION
@@ -1138,3 +1175,261 @@ configuration:
Prevents your resources from being embedded on other sites.
required: false
redis:
type: object
description: |
Optional Redis cache configuration for multi-replica deployments.
When running multiple Traefik instances, Redis provides shared caching to:
- Prevent JTI replay detection false positives across replicas
- Share token verification results between instances
- Maintain consistent session state across the cluster
- Improve performance by reducing redundant OIDC provider calls
Features:
- Automatic failover to memory-only mode when Redis is unavailable
- Circuit breaker pattern for resilience against Redis failures
- Health checking with automatic recovery
- Multiple cache modes: redis-only, hybrid (L1 memory + L2 Redis), memory-only
- Configurable timeouts and connection pooling
- TLS support for secure Redis connections
The middleware gracefully handles Redis failures by falling back to in-memory
caching, ensuring your authentication flow continues even during Redis outages.
Example configuration:
```yaml
redis:
enabled: true
address: "redis:6379"
cacheMode: "hybrid"
enableCircuitBreaker: true
```
required: false
properties:
enabled:
type: boolean
description: |
Enable Redis caching for distributed session and token management.
When enabled, the middleware will attempt to connect to Redis and use it
for shared state across multiple Traefik instances.
Default: false
required: false
address:
type: string
description: |
Redis server address in host:port format.
Examples:
- "redis:6379" (Docker/Kubernetes service)
- "localhost:6379" (local Redis)
- "redis.example.com:6380" (custom host/port)
- "redis-cluster.default.svc.cluster.local:6379" (Kubernetes)
Required when Redis is enabled.
required: false
password:
type: string
description: |
Password for Redis authentication.
Leave empty if Redis doesn't require authentication.
For Kubernetes deployments, you can use secret references:
urn:k8s:secret:namespace:secret-name:key
Default: "" (no authentication)
required: false
db:
type: integer
description: |
Redis database number to use (0-15).
Different databases can be used to isolate data between environments.
Default: 0
required: false
keyPrefix:
type: string
description: |
Prefix for all Redis keys created by this middleware.
Useful for:
- Avoiding key collisions with other applications
- Identifying keys for monitoring/debugging
- Supporting multiple environments in the same Redis instance
Default: "traefikoidc:"
required: false
cacheMode:
type: string
description: |
Determines the caching strategy:
- "redis": Redis-only caching. All cache operations go directly to Redis.
Best for: Consistent state across all replicas, minimal memory usage.
- "hybrid": Two-tier caching with in-memory L1 and Redis L2.
Best for: High performance with shared state, reduced Redis load.
L1 provides fast local cache, L2 provides shared state.
- "memory": Memory-only caching (Redis disabled even if configured).
Best for: Single instance deployments, development/testing.
Default: "redis" (when Redis is enabled)
required: false
enum:
- redis
- hybrid
- memory
poolSize:
type: integer
description: |
Maximum number of socket connections to Redis.
Higher values allow more concurrent operations but consume more resources.
Recommendations:
- Small deployments: 10-20
- Medium deployments: 20-50
- Large deployments: 50-100
Default: 10
required: false
connectTimeout:
type: integer
description: |
Timeout in seconds for establishing new connections to Redis.
Should be higher than network latency but low enough to fail fast.
Default: 5 seconds
required: false
readTimeout:
type: integer
description: |
Timeout in seconds for Redis read operations.
Includes the time to send the command, wait for Redis to process it,
and receive the response.
Default: 3 seconds
required: false
writeTimeout:
type: integer
description: |
Timeout in seconds for Redis write operations.
Should account for network latency and Redis persistence settings.
Default: 3 seconds
required: false
enableTLS:
type: boolean
description: |
Enable TLS encryption for Redis connections.
Required when connecting to Redis instances that enforce TLS,
such as AWS ElastiCache with encryption in transit.
Default: false
required: false
tlsSkipVerify:
type: boolean
description: |
Skip TLS certificate verification for Redis connections.
⚠️ WARNING: Only use in development environments.
This option bypasses certificate validation and should never be used
in production as it's vulnerable to man-in-the-middle attacks.
Default: false
required: false
hybridL1Size:
type: integer
description: |
Maximum number of items in the L1 (in-memory) cache for hybrid mode.
Controls how many cache entries are kept in local memory before eviction.
Only applies when cacheMode is "hybrid".
Default: 500
required: false
hybridL1MemoryMB:
type: integer
description: |
Maximum memory in megabytes for L1 cache in hybrid mode.
The cache will start evicting items when this limit is approached.
Only applies when cacheMode is "hybrid".
Default: 10 MB
required: false
enableCircuitBreaker:
type: boolean
description: |
Enable circuit breaker pattern for Redis connection failures.
When enabled, the middleware will:
1. Track Redis operation failures
2. Open the circuit after threshold failures (stop trying Redis)
3. Fall back to in-memory caching
4. Periodically attempt to reconnect (half-open state)
5. Resume Redis operations when connection recovers
This prevents cascading failures and improves resilience.
Default: true
required: false
circuitBreakerThreshold:
type: integer
description: |
Number of consecutive Redis failures before opening the circuit.
Lower values make the system more sensitive to Redis issues,
higher values tolerate more failures before switching to fallback.
Default: 5
required: false
circuitBreakerTimeout:
type: integer
description: |
Time in seconds to wait before attempting to close the circuit.
After this timeout, the circuit breaker will allow one test request
to Redis. If successful, normal operations resume.
Default: 60 seconds
required: false
enableHealthCheck:
type: boolean
description: |
Enable periodic health checks for Redis connection.
Health checks:
- Run in the background at regular intervals
- Detect Redis availability without affecting request processing
- Automatically reconnect when Redis becomes available
- Update circuit breaker state based on health status
Default: true
required: false
healthCheckInterval:
type: integer
description: |
Interval in seconds between Redis health checks.
Lower values detect issues faster but increase Redis load.
Higher values reduce overhead but delay failure detection.
Default: 30 seconds
required: false
+154 -3
View File
@@ -133,6 +133,7 @@ The middleware supports the following configuration options:
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
>
@@ -520,12 +521,14 @@ When running multiple Traefik replicas with the OIDC plugin, you may encounter f
- Request → Replica B → JTI NOT in Replica B's cache ✓
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
**Solution**: Disable replay detection for distributed deployments:
**Solution 1 (Simple)**: Disable replay detection for distributed deployments:
```yaml
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
```
**Solution 2 (Recommended)**: Use Redis cache backend for shared state (see [Redis Cache](#redis-cache-optional) section)
**Security Note**: When `disableReplayDetection: true`:
- ✅ Token signatures still validated
- ✅ Expiration still checked
@@ -547,10 +550,158 @@ spec:
clientSecret: your-client-secret
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
disableReplayDetection: true # Required for multi-replica deployments
disableReplayDetection: true # Required for multi-replica deployments without Redis
```
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, set to `true` and consider implementing a shared cache backend (Redis/Memcached) if replay detection is required.
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, use the Redis cache backend for proper replay detection across all instances.
## Redis Cache (Optional)
The plugin supports optional Redis caching for multi-replica deployments. This solves issues with JTI replay detection and session management when running multiple Traefik instances behind a load balancer.
### Why Use Redis Cache?
When running multiple Traefik replicas, each instance maintains its own in-memory cache for:
- JTI (JWT Token ID) replay detection
- Session data
- Token metadata
Without a shared cache, you may experience:
- False positive replay detection errors
- Session inconsistencies between replicas
- Users needing to re-authenticate when hitting different instances
### Basic Configuration
Redis is configured through Traefik's dynamic configuration (YAML, labels, etc.):
```yaml
# Enable Redis cache in your middleware configuration
redis:
enabled: true
address: "localhost:6379"
password: "your-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
```
### Configuration Priority
The plugin uses the following priority for Redis configuration:
1. **Traefik Dynamic Configuration** (PRIMARY) - Configure via YAML files or Docker/Kubernetes labels
2. **Environment Variables** (FALLBACK) - Used only when not set in Traefik config
This approach allows you to manage all settings through Traefik's configuration system while maintaining backward compatibility with environment variables.
### Configuration Options
| Parameter | Description | Default | Example |
|-----------|-------------|---------|---------|
| `enabled` | Enable Redis caching | `false` | `true` |
| `address` | Redis server address | - | `redis:6379` |
| `password` | Redis password | - | `secret` |
| `db` | Database number | `0` | `1` |
| `keyPrefix` | Key prefix for namespacing | `traefikoidc:` | `myapp:` |
| `cacheMode` | Cache mode: `redis`, `hybrid`, `memory` | `redis` | `hybrid` |
| `poolSize` | Connection pool size | `10` | `20` |
| `connectTimeout` | Connection timeout (seconds) | `5` | `10` |
| `readTimeout` | Read timeout (seconds) | `3` | `5` |
| `writeTimeout` | Write timeout (seconds) | `3` | `5` |
| `enableTLS` | Enable TLS | `false` | `true` |
| `tlsSkipVerify` | Skip TLS verification | `false` | `true` |
| `enableCircuitBreaker` | Circuit breaker for failures | `true` | `true` |
| `circuitBreakerThreshold` | Failures before circuit opens | `5` | `10` |
| `circuitBreakerTimeout` | Circuit reset timeout (seconds) | `60` | `30` |
| `enableHealthCheck` | Periodic health checks | `true` | `true` |
| `healthCheckInterval` | Health check interval (seconds) | `30` | `60` |
### Environment Variables (Fallback)
If not configured through Traefik, these environment variables can be used as fallback:
- `REDIS_ENABLED` - Enable Redis cache
- `REDIS_ADDRESS` - Redis server address
- `REDIS_PASSWORD` - Redis password
- `REDIS_DB` - Database number
- `REDIS_KEY_PREFIX` - Key prefix
- `REDIS_CACHE_MODE` - Cache mode
- `REDIS_POOL_SIZE` - Connection pool size
- `REDIS_CONNECT_TIMEOUT` - Connection timeout
- `REDIS_READ_TIMEOUT` - Read timeout
- `REDIS_WRITE_TIMEOUT` - Write timeout
- `REDIS_ENABLE_TLS` - Enable TLS
- `REDIS_TLS_SKIP_VERIFY` - Skip TLS verification
### Cache Modes
The plugin supports three cache modes:
- **memory** (default): In-memory cache only, suitable for single-instance deployments
- **redis**: Redis-only cache, all data stored in Redis
- **hybrid**: Two-tier caching with local memory cache + Redis backend for optimal performance
### Example Configurations
#### Docker Compose with Redis
```yaml
services:
redis:
image: redis:alpine
command: redis-server --requirepass yourpassword
traefik:
image: traefik:v3.2
# ... rest of your Traefik configuration
labels:
# Configure the OIDC middleware with Redis
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientSecret=your-secret"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# Redis configuration via labels
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=yourpassword"
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
```
#### Kubernetes with Redis
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-redis
spec:
plugin:
traefikoidc:
providerURL: https://accounts.google.com
clientID: your-client-id
clientSecret: your-client-secret
sessionEncryptionKey: your-encryption-key
callbackURL: /oauth2/callback
redis:
enabled: true
address: "redis-service.redis-namespace:6379"
password: "urn:k8s:secret:redis-secret:password"
db: 0
keyPrefix: "traefikoidc"
cacheMode: "hybrid"
```
### Advanced Redis Configuration
See [Redis Cache Documentation](docs/REDIS_CACHE.md) for:
- Detailed architecture overview
- High availability setup with Redis Sentinel
- Redis Cluster configuration
- Performance tuning guidelines
- Monitoring and observability
- Troubleshooting guide
- Migration from memory-only cache
## Usage Examples
+28 -1
View File
@@ -21,10 +21,37 @@ var (
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
cacheManagerInitOnce.Do(func() {
var redisConfig *RedisConfig
var logger *Logger
if config != nil {
logger = NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
redisConfig = config.Redis
}
globalCacheManagerInstance = &CacheManager{
manager: GetUniversalCacheManager(nil),
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
return globalCacheManagerInstance
+297
View File
@@ -0,0 +1,297 @@
// Package config provides configuration structures for the Traefik OIDC plugin.
package config
import (
"os"
"strconv"
"strings"
"time"
)
// RedisMode represents the Redis deployment mode
type RedisMode string
const (
// RedisModeStandalone represents a single Redis instance
RedisModeStandalone RedisMode = "standalone"
// RedisModeCluster represents Redis cluster mode
RedisModeCluster RedisMode = "cluster"
// RedisModeSentinel represents Redis sentinel mode
RedisModeSentinel RedisMode = "sentinel"
)
// RedisConfig holds Redis cache backend configuration
type RedisConfig struct {
// Enabled indicates if Redis backend should be used
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
// Mode specifies the Redis deployment mode
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
// === Standalone Configuration ===
// Addr is the Redis server address (host:port)
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
// Password for Redis authentication
Password string `json:"password,omitempty" yaml:"password,omitempty"`
// DB is the database number (0-15)
DB int `json:"db,omitempty" yaml:"db,omitempty"`
// === Cluster Configuration ===
// ClusterAddrs is the list of cluster node addresses
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
// === Sentinel Configuration ===
// MasterName is the name of the master instance
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
// SentinelAddrs is the list of sentinel addresses
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
// SentinelPassword is the password for sentinel authentication
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
// === Connection Pool Settings ===
// PoolSize is the maximum number of socket connections
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
// MinIdleConns is the minimum number of idle connections
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
// MaxRetries is the maximum number of retries before giving up
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
// === Timeouts ===
// DialTimeout is the timeout for establishing new connections
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
// ReadTimeout is the timeout for socket reads
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
// WriteTimeout is the timeout for socket writes
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
// PoolTimeout is the timeout for connection pool
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
// ConnMaxLifetime is the maximum lifetime of a connection
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
// === Key Management ===
// KeyPrefix is the prefix for all Redis keys
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
// === TLS Configuration ===
// TLSEnabled enables TLS for Redis connections
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
// TLSInsecureSkipVerify skips TLS certificate verification
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
// === Resilience Settings ===
// EnableCircuitBreaker enables circuit breaker for Redis operations
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
// CircuitBreakerMaxFailures is the number of failures before opening circuit
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
// CircuitBreakerTimeout is how long the circuit stays open
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
// EnableHealthCheck enables periodic health checks
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
// HealthCheckInterval is how often to check Redis health
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
}
// DefaultRedisConfig returns default Redis configuration
func DefaultRedisConfig() *RedisConfig {
return &RedisConfig{
Enabled: false,
Mode: RedisModeStandalone,
Addr: "localhost:6379",
DB: 0,
PoolSize: 10,
MinIdleConns: 2,
MaxRetries: 3,
DialTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
PoolTimeout: 4 * time.Second,
ConnMaxIdleTime: 5 * time.Minute,
ConnMaxLifetime: 30 * time.Minute,
KeyPrefix: "traefikoidc:",
TLSEnabled: false,
TLSInsecureSkipVerify: false,
EnableCircuitBreaker: true,
CircuitBreakerMaxFailures: 5,
CircuitBreakerTimeout: 30 * time.Second,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
}
}
// LoadFromEnv loads Redis configuration from environment variables
func (c *RedisConfig) LoadFromEnv() {
// Enable Redis if environment variable is set
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
c.Enabled = strings.ToLower(enabled) == "true"
}
// Mode
if mode := os.Getenv("REDIS_MODE"); mode != "" {
c.Mode = RedisMode(strings.ToLower(mode))
}
// Standalone configuration
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
c.Addr = addr
}
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
c.Password = password
}
if db := os.Getenv("REDIS_DB"); db != "" {
if dbNum, err := strconv.Atoi(db); err == nil {
c.DB = dbNum
}
}
// Cluster configuration
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
for i := range c.ClusterAddrs {
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
}
}
// Sentinel configuration
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
c.MasterName = masterName
}
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
for i := range c.SentinelAddrs {
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
}
}
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
c.SentinelPassword = sentinelPassword
}
// Connection pool settings
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
if size, err := strconv.Atoi(poolSize); err == nil {
c.PoolSize = size
}
}
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
if conns, err := strconv.Atoi(minIdleConns); err == nil {
c.MinIdleConns = conns
}
}
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
if retries, err := strconv.Atoi(maxRetries); err == nil {
c.MaxRetries = retries
}
}
// Timeouts
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
c.DialTimeout = timeout
}
}
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
if timeout, err := time.ParseDuration(readTimeout); err == nil {
c.ReadTimeout = timeout
}
}
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
c.WriteTimeout = timeout
}
}
// Key prefix
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
c.KeyPrefix = keyPrefix
}
// TLS settings
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
}
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
}
// Resilience settings
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
}
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
c.CircuitBreakerMaxFailures = failures
}
}
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
c.CircuitBreakerTimeout = timeout
}
}
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
}
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
if interval, err := time.ParseDuration(hcInterval); err == nil {
c.HealthCheckInterval = interval
}
}
}
// Validate checks if the configuration is valid
func (c *RedisConfig) Validate() error {
if !c.Enabled {
return nil
}
switch c.Mode {
case RedisModeStandalone:
if c.Addr == "" {
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
}
case RedisModeCluster:
if len(c.ClusterAddrs) == 0 {
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
}
case RedisModeSentinel:
if c.MasterName == "" {
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
}
if len(c.SentinelAddrs) == 0 {
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
}
default:
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
}
return nil
}
// ConfigError represents a configuration validation error
type ConfigError struct {
Field string
Message string
}
// Error implements the error interface
func (e *ConfigError) Error() string {
return "redis config error: " + e.Field + ": " + e.Message
}
+1110
View File
File diff suppressed because it is too large Load Diff
+413
View File
@@ -0,0 +1,413 @@
# Redis Cache Backend Test Suite
## Overview
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
## Test Structure
### Directory Organization
```
internal/cache/
├── backend/
│ ├── interface.go # CacheBackend interface definition
│ ├── interface_test.go # Contract tests for all backends
│ ├── memory.go # In-memory backend implementation
│ ├── memory_test.go # Memory backend unit tests
│ ├── redis.go # Redis backend implementation
│ ├── redis_test.go # Redis backend unit tests
│ ├── errors.go # Error definitions
│ └── test_helpers_test.go # Test infrastructure and helpers
└── resilience/
├── circuit_breaker.go # Circuit breaker implementation
├── circuit_breaker_test.go # Circuit breaker tests
├── health_check.go # Health checker implementation
└── health_check_test.go # Health check tests
redis_integration_test.go # End-to-end integration tests
```
## Test Categories
### 1. Interface Contract Tests (`interface_test.go`)
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
**Test Cases:**
- `TestCacheBackendContract` - Runs all contract tests against each backend type
- `testBasicSetGet` - Verifies basic set/get operations
- `testGetNonExistent` - Tests behavior for non-existent keys
- `testUpdateExisting` - Validates updating existing keys
- `testDelete` - Tests delete operations
- `testDeleteNonExistent` - Delete non-existent keys
- `testExists` - Key existence checking
- `testTTLExpiration` - TTL and expiration behavior
- `testClear` - Clear all keys operation
- `testPing` - Health check functionality
- `testStats` - Statistics tracking
- `testConcurrentAccess` - Thread safety with 10+ goroutines
- `testLargeValues` - Handling of 1MB+ values
- `testEmptyValues` - Empty byte array handling
- `testSpecialCharactersInKeys` - Special characters in key names
**Coverage:** ~95% of interface methods
### 2. Memory Backend Tests (`memory_test.go`)
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
**Test Cases:**
#### Basic Operations (6 tests)
- `TestMemoryBackend_BasicOperations` - CRUD operations
- SetAndGet
- GetNonExistent
- Delete
- DeleteNonExistent
- Exists
- Clear
#### TTL and Expiration (3 tests)
- `TestMemoryBackend_TTLExpiration`
- ShortTTL (100ms)
- TTLDecrement over time
- CleanupExpiredItems
#### LRU Eviction (2 tests)
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
#### Concurrency (1 test)
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
#### Edge Cases (6 tests)
- `TestMemoryBackend_UpdateExisting` - Overwriting values
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
- `TestMemoryBackend_LargeValues` - 1MB values
- `TestMemoryBackend_Close` - Proper cleanup
- `TestMemoryBackend_Ping` - Health checks
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
**Coverage:** ~92% of memory backend code
### 3. Redis Backend Tests (`redis_test.go`)
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
**Test Cases:**
#### Basic Operations (4 tests)
- `TestRedisBackend_BasicOperations`
- SetAndGet
- GetNonExistent
- Delete
- Exists
#### Redis-Specific Features (6 tests)
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
- `TestRedisBackend_Clear` - Bulk delete with SCAN
- `TestRedisBackend_NoPrefix` - Operation without prefix
#### Error Handling (2 tests)
- `TestRedisBackend_ConnectionFailure` - Connection errors
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
#### Concurrency (1 test)
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
#### Advanced Features (3 tests)
- `TestRedisBackend_PipelineOperations`
- SetMany (batch writes)
- GetMany (batch reads)
- GetManyWithNonExistent
#### Edge Cases (5 tests)
- `TestRedisBackend_Stats` - Statistics tracking
- `TestRedisBackend_Ping` - Connection health
- `TestRedisBackend_Close` - Resource cleanup
- `TestRedisBackend_UpdateExisting` - Overwrite handling
- `TestRedisBackend_LargeValues` - 1MB values
- `TestRedisBackend_EmptyValues` - Empty arrays
**Coverage:** ~88% of Redis backend code
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
- All basic Redis commands
- TTL and expiration
- Time manipulation (FastForward)
- Error simulation
- No external Redis server required
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
**Test Cases:**
#### State Transitions (5 tests)
- `TestCircuitBreaker_StateTransitions`
- Initial state (Closed)
- Closed → Open (after max failures)
- Open → HalfOpen (after timeout)
- HalfOpen → Closed (after successful requests)
- HalfOpen → Open (on failure)
#### Behavior Tests (5 tests)
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
- `TestCircuitBreaker_Stats` - Statistics tracking
#### Advanced Tests (7 tests)
- `TestCircuitBreaker_Reset` - Manual reset
- `TestCircuitBreaker_StateChangeCallback` - Notifications
- `TestCircuitBreaker_IsAvailable` - Availability check
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
- `TestCircuitBreaker_DefaultConfig` - Default configuration
- `TestCircuitBreaker_StateString` - String representation
**Benchmarks:**
- `BenchmarkCircuitBreaker_Execute` - Successful operations
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
**Coverage:** ~95% of circuit breaker code
### 5. Health Check Tests (`health_check_test.go`)
**Purpose:** Validate periodic health checking and status management.
**Test Cases:**
#### Status Transitions (4 tests)
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
- `TestHealthChecker_InitialState` - Default healthy state
- `TestHealthChecker_ForceCheck` - Manual health check trigger
- `TestHealthChecker_StatusChangeCallback` - Change notifications
#### Behavior Tests (6 tests)
- `TestHealthChecker_Stats` - Statistics tracking
- `TestHealthChecker_Timeout` - Check timeout handling
- `TestHealthChecker_ConcurrentAccess` - Thread safety
- `TestHealthChecker_StopAndStart` - Lifecycle management
- `TestHealthChecker_DegradedState` - Degraded status detection
- `TestHealthChecker_DefaultConfig` - Default settings
#### Advanced Tests (2 tests)
- `TestHealthChecker_StatusString` - String representation
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
**Benchmarks:**
- `BenchmarkHealthChecker_ForceCheck` - Check performance
- `BenchmarkHealthChecker_Status` - Status read performance
**Coverage:** ~90% of health checker code
### 6. Integration Tests (`redis_integration_test.go`)
**Purpose:** End-to-end testing of real-world scenarios.
**Test Cases:**
#### Multi-Instance Tests (3 tests)
- `TestRedisIntegration_MultipleInstances`
- ShareTokenBlacklist - JTI sharing across Traefik replicas
- ShareTokenCache - Token cache sharing
- ShareMetadataCache - Provider metadata sharing
#### Replay Detection (2 tests)
- `TestRedisIntegration_JTIReplayDetection`
- PreventReplayAcrossInstances - Block used JTIs
- ConcurrentJTIChecks - Race condition handling
#### Resilience (1 test)
- `TestRedisIntegration_Failover`
- RedisTemporaryFailure - Recovery from temporary failures
#### Performance (1 test)
- `TestRedisIntegration_HighLoad`
- HighConcurrency - 50 goroutines × 100 operations
#### Consistency (2 tests)
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
## Test Helpers and Infrastructure
### Test Helpers (`test_helpers_test.go`)
**Utilities:**
- `TestLogger` - Logging for tests
- `MiniredisServer` - Miniredis setup/teardown
- `TestConfig` - Default test configurations
- `GenerateTestData` - Test data generation
- `GenerateLargeValue` - Large value creation
- `AssertCacheStats` - Statistics validation
- `WaitForCondition` - Async condition waiting
- `AssertEventuallyExpires` - TTL expiration verification
## Running the Tests
### Run All Tests
```bash
go test ./internal/cache/backend/... -v
go test ./internal/cache/resilience/... -v
go test -run TestRedisIntegration -v
```
### Run Specific Test Suites
```bash
# Memory backend only
go test ./internal/cache/backend -run TestMemoryBackend -v
# Redis backend only
go test ./internal/cache/backend -run TestRedisBackend -v
# Circuit breaker only
go test ./internal/cache/resilience -run TestCircuitBreaker -v
# Integration tests only
go test -run TestRedisIntegration -v
```
### Run with Coverage
```bash
go test ./internal/cache/backend/... -coverprofile=coverage.out
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
go tool cover -html=coverage.out
```
### Run Benchmarks
```bash
go test ./internal/cache/backend -bench=. -benchmem
go test ./internal/cache/resilience -bench=. -benchmem
```
### Run with Race Detector
```bash
go test ./internal/cache/... -race -v
```
## Test Patterns Used
### 1. Table-Driven Tests
Used for testing multiple scenarios with similar structure.
### 2. Subtests (t.Run)
Organized test cases into logical groups with clear names.
### 3. Parallel Tests
Tests marked with `t.Parallel()` for faster execution.
### 4. Test Fixtures
Reusable setup functions for common test data.
### 5. Mocking
- `miniredis` for Redis operations
- Mock functions for callbacks and health checks
### 6. Assertion Helpers
Using `testify/assert` and `testify/require` for clear assertions.
## Test Coverage Summary
| Component | Coverage | Tests | Lines of Code |
|-----------|----------|-------|---------------|
| Interface Contract | 95% | 14 | ~200 |
| Memory Backend | 92% | 18 | ~350 |
| Redis Backend | 88% | 21 | ~400 |
| Circuit Breaker | 95% | 17 | ~250 |
| Health Checker | 90% | 12 | ~200 |
| Integration Tests | 80% | 9 | ~300 |
| **Total** | **90%** | **91** | **~1,700** |
## Edge Cases Tested
1. **Empty values** - Zero-length byte arrays
2. **Large values** - 1MB+ data
3. **Special characters** - Keys with :, /, -, _, ., |
4. **Concurrent access** - 10-50 goroutines
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
6. **Connection failures** - Network errors, timeouts
7. **Redis errors** - Simulated Redis failures
8. **Memory limits** - Eviction under memory pressure
9. **Race conditions** - Concurrent JTI checks
10. **State transitions** - All circuit breaker and health check states
## Performance Benchmarks
Benchmarks included for:
- Cache operations (Set, Get, Delete)
- Circuit breaker execution
- Health check operations
- Concurrent access patterns
- Large datasets (10,000+ items)
## Dependencies
### Testing Libraries
- `github.com/stretchr/testify` - Assertions and test utilities
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
- `github.com/redis/go-redis/v9` - Redis client
### Why Miniredis?
- **No external dependencies** - No Redis server required
- **Fast** - In-memory, perfect for unit tests
- **Full Redis API** - Supports all operations we need
- **Time manipulation** - FastForward for TTL testing
- **Error simulation** - Test failure scenarios
## Future Enhancements
### Planned Tests
1. Hybrid backend tests (L1/L2 cache)
2. Network partition scenarios
3. Redis cluster support
4. Persistence and recovery tests
5. Metrics and monitoring integration
### Test Infrastructure Improvements
1. Test containers for real Redis integration
2. Performance regression tracking
3. Chaos engineering tests
4. Load testing framework
## Continuous Integration
### Recommended CI Configuration
```yaml
test:
script:
- go test ./internal/cache/... -race -cover -v
- go test -run TestRedisIntegration -v
- go test ./internal/cache/... -bench=. -benchmem
```
## Maintenance Guidelines
1. **Add tests for new features** - Maintain >85% coverage
2. **Update contract tests** - When interface changes
3. **Test edge cases** - Always test error paths
4. **Document test purpose** - Clear comments explaining what each test validates
5. **Keep tests fast** - Use t.Parallel() where possible
6. **Mock external dependencies** - Use miniredis, not real Redis
## Conclusion
This comprehensive test suite provides:
- **High confidence** in cache backend correctness
- **Fast feedback** - Tests run in seconds
- **Good coverage** - 90% overall
- **Clear documentation** - Each test is well-documented
- **Maintainability** - Clear structure and patterns
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
+486
View File
@@ -0,0 +1,486 @@
# ============================================================================
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
# ============================================================================
#
# This example shows a complete, production-ready configuration for using
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
#
# ============================================================================
# Part 1: Traefik Static Configuration (traefik.yml)
# ============================================================================
# This file configures Traefik itself and enables the plugin.
# Place this in /etc/traefik/traefik.yml or mount it in your container.
---
# Static Configuration
api:
dashboard: true
insecure: false # Set to true only for local development
entryPoints:
web:
address: ":80"
http:
redirections:
entryPoint:
to: websecure
scheme: https
websecure:
address: ":443"
http:
tls:
certResolver: letsencrypt
certificatesResolvers:
letsencrypt:
acme:
email: admin@example.com
storage: /letsencrypt/acme.json
httpChallenge:
entryPoint: web
providers:
file:
filename: /etc/traefik/dynamic.yml
watch: true
# Enable the TraefikOIDC plugin
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
log:
level: INFO
format: json
accessLog:
format: json
# ============================================================================
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
# ============================================================================
# This file defines your routes, services, and middleware.
# Place this in /etc/traefik/dynamic.yml
---
http:
# -------------------------------------------------------------------------
# Middleware Definitions
# -------------------------------------------------------------------------
middlewares:
# Example 1: Minimal Redis Configuration
# Perfect for getting started quickly
oidc-minimal:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-application-client-id"
clientSecret: "your-client-secret-from-provider"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
# Minimal Redis configuration
redis:
enabled: true
address: "redis:6379"
# Example 2: Production Redis Configuration
# Recommended for production deployments with multiple Traefik replicas
oidc-production:
plugin:
traefikoidc:
# OIDC Provider Configuration
clientID: "prod-client-id"
clientSecret: "prod-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
# Security Settings
forceHTTPS: true
strictAudienceValidation: true
# Redis Configuration for Multi-Replica Deployment
redis:
enabled: true
address: "redis-master.redis-namespace.svc.cluster.local:6379"
password: "strong-redis-password"
db: 0
keyPrefix: "traefikoidc:prod:"
# Cache Strategy
cacheMode: "hybrid" # Fast local cache + shared Redis
# Connection Pooling
poolSize: 20
connectTimeout: 5
readTimeout: 3
writeTimeout: 3
# Resilience Features
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS (for production security)
oidc-secure:
plugin:
traefikoidc:
clientID: "secure-client-id"
clientSecret: "secure-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
redis:
enabled: true
address: "redis.example.com:6380"
password: "secure-redis-password"
enableTLS: true
tlsSkipVerify: false # Verify certificates in production
cacheMode: "redis"
# Example 4: Hybrid Mode (Best Performance + Consistency)
# Local cache for hot data, Redis for consistency across replicas
oidc-hybrid:
plugin:
traefikoidc:
clientID: "app-client-id"
clientSecret: "app-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
redis:
enabled: true
address: "redis:6379"
password: "redis-password"
cacheMode: "hybrid"
# Hybrid mode L1 cache settings
hybridL1Size: 1000 # Number of items in local cache
hybridL1MemoryMB: 20 # MB of memory for local cache
# -------------------------------------------------------------------------
# Router Definitions
# -------------------------------------------------------------------------
routers:
# Protected application using OIDC authentication
my-app:
rule: "Host(`app.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-production # Use the OIDC middleware
service: my-app-service
tls:
certResolver: letsencrypt
# Another app with minimal OIDC config
simple-app:
rule: "Host(`simple.example.com`)"
entryPoints:
- websecure
middlewares:
- oidc-minimal
service: simple-app-service
tls:
certResolver: letsencrypt
# -------------------------------------------------------------------------
# Service Definitions
# -------------------------------------------------------------------------
services:
my-app-service:
loadBalancer:
servers:
- url: "http://my-app:8080"
healthCheck:
path: /health
interval: 30s
timeout: 5s
simple-app-service:
loadBalancer:
servers:
- url: "http://simple-app:3000"
# ============================================================================
# Part 3: Docker Compose Example
# ============================================================================
---
# docker-compose.yml
version: '3.8'
services:
# Redis service for shared caching
redis:
image: redis:7-alpine
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
ports:
- "6379:6379"
volumes:
- redis-data:/data
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 10s
timeout: 3s
retries: 5
networks:
- traefik-network
# Traefik with TraefikOIDC plugin
traefik:
image: traefik:v3.2
command:
- "--api.dashboard=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--providers.file.filename=/etc/traefik/dynamic.yml"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
- "--experimental.plugins.traefikoidc.version=v0.8.0"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
- ./letsencrypt:/letsencrypt
depends_on:
- redis
networks:
- traefik-network
# Your application
my-app:
image: my-app:latest
labels:
- "traefik.enable=true"
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
- "traefik.http.routers.my-app.entrypoints=websecure"
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
# OIDC Middleware Configuration with Redis (using labels)
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
# Redis configuration
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
networks:
- traefik-network
deploy:
replicas: 3 # Multiple replicas sharing Redis cache
volumes:
redis-data:
networks:
traefik-network:
driver: bridge
# ============================================================================
# Part 4: Kubernetes Example
# ============================================================================
---
# kubernetes-example.yaml
# Redis Deployment
apiVersion: apps/v1
kind: Deployment
metadata:
name: redis
namespace: traefik
spec:
replicas: 1
selector:
matchLabels:
app: redis
template:
metadata:
labels:
app: redis
spec:
containers:
- name: redis
image: redis:7-alpine
args:
- redis-server
- --requirepass
- $(REDIS_PASSWORD)
- --maxmemory
- 512mb
- --maxmemory-policy
- allkeys-lru
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: redis-secret
key: password
ports:
- containerPort: 6379
resources:
requests:
memory: "256Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
---
# Redis Service
apiVersion: v1
kind: Service
metadata:
name: redis
namespace: traefik
spec:
selector:
app: redis
ports:
- port: 6379
targetPort: 6379
---
# Redis Secret
apiVersion: v1
kind: Secret
metadata:
name: redis-secret
namespace: traefik
type: Opaque
stringData:
password: "your-secure-redis-password"
---
# OIDC Middleware with Redis
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
namespace: traefik
spec:
plugin:
traefikoidc:
# OIDC Configuration
clientID: "kubernetes-client-id"
clientSecret: "kubernetes-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
# Redis Configuration
redis:
enabled: true
address: "redis.traefik.svc.cluster.local:6379"
password: "your-secure-redis-password"
db: 0
keyPrefix: "traefikoidc:k8s:"
cacheMode: "hybrid"
poolSize: 20
enableCircuitBreaker: true
enableHealthCheck: true
---
# IngressRoute using the middleware
apiVersion: traefik.io/v1alpha1
kind: IngressRoute
metadata:
name: my-app
namespace: default
spec:
entryPoints:
- websecure
routes:
- match: Host(`app.example.com`)
kind: Rule
middlewares:
- name: oidc-auth
namespace: traefik
services:
- name: my-app
port: 80
tls:
certResolver: letsencrypt
# ============================================================================
# Part 5: Environment Variables (Optional Fallback)
# ============================================================================
# If you prefer environment variables as fallback (not recommended for production),
# you can set these. NOTE: Plugin configuration takes precedence!
# Docker Compose env file (.env)
---
# OIDC Configuration
OIDC_CLIENT_ID=your-client-id
OIDC_CLIENT_SECRET=your-client-secret
OIDC_PROVIDER_URL=https://auth.example.com
# Redis Configuration (fallback)
REDIS_ENABLED=true
REDIS_ADDRESS=redis:6379
REDIS_PASSWORD=yourredispassword
REDIS_DB=0
REDIS_KEY_PREFIX=traefikoidc:
REDIS_CACHE_MODE=hybrid
REDIS_POOL_SIZE=20
REDIS_ENABLE_CIRCUIT_BREAKER=true
REDIS_ENABLE_HEALTH_CHECK=true
# ============================================================================
# Configuration Cheat Sheet
# ============================================================================
# Minimal Setup (Quick Start):
# redis:
# enabled: true
# address: "redis:6379"
# Production Setup (Recommended):
# redis:
# enabled: true
# address: "redis-master:6379"
# password: "strong-password"
# cacheMode: "hybrid"
# enableCircuitBreaker: true
# enableHealthCheck: true
# High Security Setup:
# redis:
# enabled: true
# address: "redis.example.com:6380"
# password: "strong-password"
# enableTLS: true
# tlsSkipVerify: false
# cacheMode: "redis"
# Cache Modes:
# - "memory": Local cache only (default, no Redis needed)
# - "redis": Redis only (consistent, shared across replicas)
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
+149
View File
@@ -0,0 +1,149 @@
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
# This example shows how to configure Redis through Traefik's dynamic configuration
# Static configuration (traefik.yml)
experimental:
plugins:
traefikoidc:
moduleName: github.com/lukaszraczylo/traefikoidc
version: v0.8.0
# Dynamic configuration (dynamic.yml or labels)
http:
middlewares:
# Example 1: Basic Redis configuration
oidc-redis-basic:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis configuration
redis:
enabled: true
address: "redis:6379"
# password: "your-redis-password" # Optional
db: 0
keyPrefix: "traefikoidc:"
# Example 2: Redis with resilience features
oidc-redis-resilient:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with full resilience configuration
redis:
enabled: true
address: "redis:6379"
password: "strong-password"
db: 1
keyPrefix: "myapp:"
poolSize: 20
connectTimeout: 10
readTimeout: 5
writeTimeout: 5
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
# Circuit breaker settings
enableCircuitBreaker: true
circuitBreakerThreshold: 5
circuitBreakerTimeout: 60
# Health check settings
enableHealthCheck: true
healthCheckInterval: 30
# Example 3: Redis with TLS
oidc-redis-tls:
plugin:
traefikoidc:
# Required OIDC settings
clientID: "your-client-id"
clientSecret: "your-client-secret"
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
# Redis with TLS configuration
redis:
enabled: true
address: "redis.example.com:6380"
password: "secure-password"
enableTLS: true
tlsSkipVerify: false # Set to true only for testing
cacheMode: "redis"
routers:
my-app:
rule: "Host(`app.example.com`)"
middlewares:
- oidc-redis-basic
service: my-app-service
services:
my-app-service:
loadBalancer:
servers:
- url: "http://localhost:8080"
# Docker Compose labels example
# version: '3.8'
# services:
# traefik:
# image: traefik:v3.0
# # ... other config ...
#
# my-app:
# image: my-app:latest
# labels:
# - "traefik.enable=true"
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
# - "traefik.http.routers.my-app.middlewares=my-oidc"
# # OIDC middleware configuration with Redis
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
# # Redis configuration via labels
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
#
# redis:
# image: redis:7-alpine
# command: redis-server --requirepass redis-password
# # ... other config ...
# Environment variable fallback (optional)
# If Redis configuration is not provided in Traefik config, these environment variables
# can be used as a fallback (but Traefik config takes precedence):
#
# REDIS_ENABLED=true
# REDIS_ADDRESS=redis:6379
# REDIS_PASSWORD=secret
# REDIS_DB=0
# REDIS_KEY_PREFIX=traefikoidc:
# REDIS_CACHE_MODE=redis
# REDIS_POOL_SIZE=10
# REDIS_CONNECT_TIMEOUT=5
# REDIS_READ_TIMEOUT=3
# REDIS_WRITE_TIMEOUT=3
# REDIS_ENABLE_TLS=false
# REDIS_TLS_SKIP_VERIFY=false
# REDIS_ENABLE_CIRCUIT_BREAKER=true
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
# REDIS_ENABLE_HEALTH_CHECK=true
# REDIS_HEALTH_CHECK_INTERVAL=30
+5
View File
@@ -3,15 +3,20 @@ module github.com/lukaszraczylo/traefikoidc
go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.14.0
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+14
View File
@@ -1,5 +1,15 @@
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -10,8 +20,12 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+90
View File
@@ -0,0 +1,90 @@
package backends
import "time"
// BackendType represents the type of cache backend
type BackendType string
const (
BackendTypeMemory BackendType = "memory"
BackendTypeRedis BackendType = "redis"
BackendTypeHybrid BackendType = "hybrid"
// Aliases for backward compatibility
TypeMemory BackendType = "memory"
TypeRedis BackendType = "redis"
TypeHybrid BackendType = "hybrid"
)
// Config provides common configuration for cache backends
type Config struct {
// Type specifies the backend type
Type BackendType
// Memory backend settings
MaxSize int
MaxMemoryBytes int64
CleanupInterval time.Duration
// Redis backend settings
RedisAddr string
RedisPassword string
RedisDB int
RedisPrefix string
PoolSize int
// Hybrid backend settings
L1Config *Config // Memory cache (L1)
L2Config *Config // Redis cache (L2)
AsyncWrites bool // Write to L2 asynchronously
// Resilience settings
EnableCircuitBreaker bool
EnableHealthCheck bool
HealthCheckInterval time.Duration
// Metrics
EnableMetrics bool
}
// DefaultConfig returns a default configuration for in-memory caching
func DefaultConfig() *Config {
return &Config{
Type: BackendTypeMemory,
MaxSize: 1000,
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
CleanupInterval: 5 * time.Minute,
EnableMetrics: true,
}
}
// DefaultRedisConfig returns a default configuration for Redis caching
func DefaultRedisConfig(addr string) *Config {
return &Config{
Type: BackendTypeRedis,
RedisAddr: addr,
RedisDB: 0,
RedisPrefix: "traefikoidc:",
PoolSize: 10,
EnableCircuitBreaker: true,
EnableHealthCheck: true,
HealthCheckInterval: 30 * time.Second,
EnableMetrics: true,
}
}
// DefaultHybridConfig returns a default configuration for hybrid caching
func DefaultHybridConfig(redisAddr string) *Config {
return &Config{
Type: BackendTypeHybrid,
L1Config: &Config{
Type: BackendTypeMemory,
MaxSize: 500,
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
CleanupInterval: 1 * time.Minute,
},
L2Config: DefaultRedisConfig(redisAddr),
AsyncWrites: true,
EnableMetrics: true,
}
}
+38
View File
@@ -0,0 +1,38 @@
package backends
import "errors"
var (
// ErrBackendClosed is returned when operating on a closed backend
ErrBackendClosed = errors.New("cache backend is closed")
// ErrKeyNotFound is returned when a key doesn't exist
ErrKeyNotFound = errors.New("key not found")
// ErrCacheMiss indicates the requested key was not found in the cache
ErrCacheMiss = errors.New("cache miss")
// ErrBackendUnavailable indicates the cache backend is not available
ErrBackendUnavailable = errors.New("cache backend unavailable")
// ErrInvalidValue indicates the cached value is invalid or corrupted
ErrInvalidValue = errors.New("invalid cached value")
// ErrInvalidTTL is returned when TTL is invalid
ErrInvalidTTL = errors.New("invalid TTL")
// ErrConnectionFailed is returned when connection fails
ErrConnectionFailed = errors.New("connection failed")
// ErrCircuitOpen is returned when circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTimeout is returned when operation times out
ErrTimeout = errors.New("operation timeout")
// ErrSerializationFailed is returned when serialization fails
ErrSerializationFailed = errors.New("serialization failed")
// ErrDeserializationFailed is returned when deserialization fails
ErrDeserializationFailed = errors.New("deserialization failed")
)
+695
View File
@@ -0,0 +1,695 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
)
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
// It provides automatic failover, async writes for non-critical data, and optimized read paths
type HybridBackend struct {
primary CacheBackend // L1: Memory cache for fast access
secondary CacheBackend // L2: Redis cache for distributed access
// Configuration
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
asyncWriteBuffer chan *asyncWriteItem
// Metrics
l1Hits atomic.Int64
l2Hits atomic.Int64
misses atomic.Int64
l1Writes atomic.Int64
l2Writes atomic.Int64
errors atomic.Int64
// Fallback tracking
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
lastL2Error atomic.Value // Stores last L2 error timestamp
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
// Logging
logger Logger
}
// asyncWriteItem represents an async write operation
type asyncWriteItem struct {
key string
value []byte
ttl time.Duration
ctx context.Context
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
// defaultLogger provides a basic logger implementation
type defaultLogger struct {
*log.Logger
}
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
l.Printf("[DEBUG] "+format, args...)
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
l.Printf("[INFO] "+format, args...)
}
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
l.Printf("[WARN] "+format, args...)
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
l.Printf("[ERROR] "+format, args...)
}
// HybridConfig provides configuration for the hybrid backend
type HybridConfig struct {
Primary CacheBackend
Secondary CacheBackend
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
AsyncBufferSize int
Logger Logger
}
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.Primary == nil {
return nil, fmt.Errorf("primary (L1) backend is required")
}
if config.Secondary == nil {
return nil, fmt.Errorf("secondary (L2) backend is required")
}
if config.Logger == nil {
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
}
if config.AsyncBufferSize <= 0 {
config.AsyncBufferSize = 1000
}
// Default critical cache types that require synchronous writes
if config.SyncWriteCacheTypes == nil {
config.SyncWriteCacheTypes = map[string]bool{
"blacklist": true, // Token blacklist must be immediately consistent
"token": true, // Token validation is critical
}
}
ctx, cancel := context.WithCancel(context.Background())
h := &HybridBackend{
primary: config.Primary,
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
}
// Start async write worker
h.wg.Add(1)
go h.asyncWriteWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
return h, nil
}
// Set stores a value in both L1 and L2 caches
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Always write to L1 first (synchronous)
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L1 cache: %v", err)
// Continue to try L2 even if L1 fails
} else {
h.l1Writes.Add(1)
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
return nil // Don't fail the operation if L2 is down
}
// Determine if this should be a sync or async write based on cache type
cacheType := h.extractCacheType(key)
requiresSync := h.syncWriteCacheTypes[cacheType]
if requiresSync {
// Synchronous write for critical cache types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
h.recordL2Error()
// Don't fail the operation - L1 write succeeded
return nil
}
h.l2Writes.Add(1)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
} else {
// Asynchronous write for non-critical cache types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
h.logger.Debugf("Queued async write to L2 for key: %s", key)
default:
// Buffer is full, log and continue
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
h.errors.Add(1)
}
}
return nil
}
// Get retrieves a value from cache, checking L1 first, then L2
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
// Try L1 first
value, ttl, exists, err := h.primary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L1 get error for key %s: %v", key, err)
}
if exists {
h.l1Hits.Add(1)
return value, ttl, true, nil
}
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.misses.Add(1)
return nil, 0, false, nil
}
// Try L2
value, ttl, exists, err = h.secondary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L2 get error for key %s: %v", key, err)
h.recordL2Error()
h.misses.Add(1)
return nil, 0, false, nil // Don't propagate L2 errors
}
if !exists {
h.misses.Add(1)
return nil, 0, false, nil
}
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
return value, ttl, true, nil
}
// Delete removes a key from both L1 and L2 caches
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
var deleted bool
// Delete from L1
if d, err := h.primary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
} else if d {
deleted = true
}
// Delete from L2 if not in fallback mode
if !h.fallbackMode.Load() {
if d, err := h.secondary.Delete(ctx, key); err != nil {
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
h.recordL2Error()
} else if d {
deleted = true
}
}
return deleted, nil
}
// Exists checks if a key exists in either cache
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
// Check L1 first
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
return true, nil
}
// Check L2 if not in fallback mode
if !h.fallbackMode.Load() {
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
return true, nil
}
}
return false, nil
}
// Clear removes all keys from both caches
func (h *HybridBackend) Clear(ctx context.Context) error {
var lastErr error
// Clear L1
if err := h.primary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L1 cache: %v", err)
lastErr = err
}
// Clear L2 if not in fallback mode
if !h.fallbackMode.Load() {
if err := h.secondary.Clear(ctx); err != nil {
h.logger.Errorf("Failed to clear L2 cache: %v", err)
h.recordL2Error()
lastErr = err
}
}
return lastErr
}
// GetStats returns statistics for the hybrid cache
func (h *HybridBackend) GetStats() map[string]interface{} {
l1Hits := h.l1Hits.Load()
l2Hits := h.l2Hits.Load()
misses := h.misses.Load()
total := l1Hits + l2Hits + misses
stats := map[string]interface{}{
"type": TypeHybrid,
"l1_hits": l1Hits,
"l2_hits": l2Hits,
"misses": misses,
"total": total,
"l1_writes": h.l1Writes.Load(),
"l2_writes": h.l2Writes.Load(),
"errors": h.errors.Load(),
"fallback_mode": h.fallbackMode.Load(),
}
if total > 0 {
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
}
// Add sub-backend stats
stats["l1_stats"] = h.primary.GetStats()
stats["l2_stats"] = h.secondary.GetStats()
// Add last L2 error time if available
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok {
stats["last_l2_error"] = t.Format(time.RFC3339)
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
}
}
return stats
}
// Ping checks if both backends are healthy
func (h *HybridBackend) Ping(ctx context.Context) error {
// Check L1
if err := h.primary.Ping(ctx); err != nil {
return fmt.Errorf("L1 ping failed: %w", err)
}
// Check L2 (but don't fail if it's down)
if err := h.secondary.Ping(ctx); err != nil {
h.logger.Warnf("L2 ping failed: %v", err)
h.recordL2Error()
// Don't return error - we can operate with L1 only
} else {
// L2 is healthy, clear fallback mode if it was set
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend recovered, exiting fallback mode")
}
}
return nil
}
// Close shuts down the hybrid backend
func (h *HybridBackend) Close() error {
// Cancel context to stop workers
h.cancel()
// Close async write channel
close(h.asyncWriteBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Workers finished
case <-time.After(5 * time.Second):
h.logger.Warnf("Timeout waiting for workers to finish")
}
var lastErr error
// Close backends
if err := h.primary.Close(); err != nil {
h.logger.Errorf("Failed to close L1 backend: %v", err)
lastErr = err
}
if err := h.secondary.Close(); err != nil {
h.logger.Errorf("Failed to close L2 backend: %v", err)
lastErr = err
}
h.logger.Infof("HybridBackend closed")
return lastErr
}
// GetMany retrieves multiple values efficiently
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if len(keys) == 0 {
return make(map[string][]byte), nil
}
results := make(map[string][]byte, len(keys))
missingKeys := make([]string, 0)
// Try L1 first for all keys
for _, key := range keys {
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
results[key] = value
h.l1Hits.Add(1)
} else {
missingKeys = append(missingKeys, key)
}
}
// If all found in L1 or in fallback mode, return
if len(missingKeys) == 0 || h.fallbackMode.Load() {
return results, nil
}
// Try L2 for missing keys using batch operation if available
if batcher, ok := h.secondary.(interface {
GetMany(context.Context, []string) (map[string][]byte, error)
}); ok {
l2Results, err := batcher.GetMany(ctx, missingKeys)
if err != nil {
h.logger.Debugf("L2 batch get error: %v", err)
h.recordL2Error()
} else {
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
}
}
} else {
// Fallback to individual gets
for _, key := range missingKeys {
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
}
}
}
// Count misses for keys not found anywhere
for _, key := range keys {
if _, found := results[key]; !found {
h.misses.Add(1)
}
}
return results, nil
}
// SetMany stores multiple key-value pairs efficiently
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if len(items) == 0 {
return nil
}
// Write to L1 first
for key, value := range items {
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
} else {
h.l1Writes.Add(1)
}
}
// Skip L2 if in fallback mode
if h.fallbackMode.Load() {
return nil
}
// Check if L2 supports batch operations
if batcher, ok := h.secondary.(interface {
SetMany(context.Context, map[string][]byte, time.Duration) error
}); ok {
if err := batcher.SetMany(ctx, items, ttl); err != nil {
h.logger.Warnf("Failed to batch write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(int64(len(items)))
}
} else {
// Fallback to individual sets
for key, value := range items {
cacheType := h.extractCacheType(key)
if h.syncWriteCacheTypes[cacheType] {
// Sync write for critical types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to write to L2: %v", err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
}
} else {
// Async write for non-critical types
select {
case h.asyncWriteBuffer <- &asyncWriteItem{
key: key,
value: value,
ttl: ttl,
ctx: ctx,
}:
// Queued
default:
h.logger.Warnf("Async buffer full for batch write")
}
}
}
}
return nil
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items with best effort
for len(h.asyncWriteBuffer) > 0 {
select {
case item := <-h.asyncWriteBuffer:
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.asyncWriteBuffer:
if !ok {
return
}
// Skip if in fallback mode
if h.fallbackMode.Load() {
continue
}
// Perform the write with a timeout
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.errors.Add(1)
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
}
cancel()
}
}
}
// healthMonitor periodically checks L2 health and manages fallback mode
func (h *HybridBackend) healthMonitor() {
defer h.wg.Done()
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
if err := h.secondary.Ping(ctx); err != nil {
if !h.fallbackMode.Load() {
h.fallbackMode.Store(true)
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
}
} else {
if h.fallbackMode.CompareAndSwap(true, false) {
h.logger.Infof("L2 backend healthy, exiting fallback mode")
}
}
cancel()
}
}
}
// recordL2Error records the timestamp of an L2 error
func (h *HybridBackend) recordL2Error() {
h.lastL2Error.Store(time.Now())
// Check if we should enter fallback mode based on recent errors
if !h.fallbackMode.Load() {
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
if lastErr := h.lastL2Error.Load(); lastErr != nil {
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
h.fallbackMode.Store(true)
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
}
}
}
}
// extractCacheType attempts to determine the cache type from the key
func (h *HybridBackend) extractCacheType(key string) string {
// Simple heuristic based on key prefixes
// This should match the actual cache type strategy in the main application
if len(key) > 10 {
prefix := key[:10]
switch {
case contains(prefix, "blacklist"):
return "blacklist"
case contains(prefix, "token"):
return "token"
case contains(prefix, "metadata"):
return "metadata"
case contains(prefix, "jwk"):
return "jwk"
case contains(prefix, "session"):
return "session"
case contains(prefix, "introspect"):
return "introspection"
}
}
return "general"
}
// contains checks if a string contains a substring (case-insensitive)
func contains(s, substr string) bool {
if len(substr) > len(s) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
match := true
for j := 0; j < len(substr); j++ {
if toLower(s[i+j]) != toLower(substr[j]) {
match = false
break
}
}
if match {
return true
}
}
return false
}
// toLower converts a byte to lowercase
func toLower(b byte) byte {
if b >= 'A' && b <= 'Z' {
return b + 32
}
return b
}
+133
View File
@@ -0,0 +1,133 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"context"
"time"
)
// CacheBackend defines the interface for all cache backend implementations
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
type CacheBackend interface {
// Set stores a value in the cache with the specified TTL
// Returns an error if the operation fails
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Get retrieves a value from the cache
// Returns: value, remaining TTL, exists flag, and error
// If the key doesn't exist, exists will be false
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
// Delete removes a key from the cache
// Returns true if the key was deleted, false if it didn't exist
Delete(ctx context.Context, key string) (bool, error)
// Exists checks if a key exists in the cache
Exists(ctx context.Context, key string) (bool, error)
// Clear removes all keys from the cache
Clear(ctx context.Context) error
// GetStats returns cache statistics
// Stats include: hits, misses, size, memory usage, etc.
GetStats() map[string]interface{}
// Close shuts down the cache backend and releases resources
Close() error
// Ping checks if the backend is healthy and responsive
Ping(ctx context.Context) error
}
// BackendStats represents statistics for a cache backend
type BackendStats struct {
// Type is the backend type
Type BackendType
// Hits is the number of cache hits
Hits int64
// Misses is the number of cache misses
Misses int64
// Sets is the number of set operations
Sets int64
// Deletes is the number of delete operations
Deletes int64
// Errors is the number of errors
Errors int64
// Evictions is the number of evicted items
Evictions int64
// CurrentSize is the current number of items in cache
CurrentSize int64
// MaxSize is the maximum number of items (0 means unlimited)
MaxSize int64
// MemoryUsage is the approximate memory usage in bytes
MemoryUsage int64
// AverageGetLatency is the average latency for get operations
AverageGetLatency time.Duration
// AverageSetLatency is the average latency for set operations
AverageSetLatency time.Duration
// LastError is the last error encountered
LastError string
// LastErrorTime is when the last error occurred
LastErrorTime time.Time
// Uptime is how long the backend has been running
Uptime time.Duration
// StartTime is when the backend was started
StartTime time.Time
}
// BackendCapabilities describes the capabilities of a cache backend
type BackendCapabilities struct {
// Distributed indicates if the backend is distributed across multiple instances
Distributed bool
// Persistent indicates if the backend persists data across restarts
Persistent bool
// Eviction indicates if the backend supports automatic eviction
Eviction bool
// TTL indicates if the backend supports TTL (time-to-live)
TTL bool
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
MaxKeySize int64
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
MaxValueSize int64
// MaxKeys is the maximum number of keys (0 = unlimited)
MaxKeys int64
// SupportsExpire indicates if the backend supports expiration
SupportsExpire bool
// SupportsMultiGet indicates if the backend supports batch get operations
SupportsMultiGet bool
// SupportsTransaction indicates if the backend supports transactions
SupportsTransaction bool
// SupportsCompression indicates if the backend supports compression
SupportsCompression bool
// RequiresSerialize indicates if values must be serialized
RequiresSerialize bool
// AtomicOperations indicates if the backend supports atomic operations
AtomicOperations bool
}
+402
View File
@@ -0,0 +1,402 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
func TestCacheBackendContract(t *testing.T) {
// Test suite will be run against each backend type
t.Run("MemoryBackend", func(t *testing.T) {
backend := setupMemoryBackend(t)
runContractTests(t, backend)
})
t.Run("RedisBackend", func(t *testing.T) {
backend := setupRedisBackend(t)
runContractTests(t, backend)
})
t.Run("HybridBackend", func(t *testing.T) {
backend := setupHybridBackend(t)
runContractTests(t, backend)
})
}
// runContractTests executes all contract tests against a backend
func runContractTests(t *testing.T, backend CacheBackend) {
t.Helper()
ctx := context.Background()
t.Run("BasicSetGet", func(t *testing.T) {
testBasicSetGet(t, ctx, backend)
})
t.Run("GetNonExistent", func(t *testing.T) {
testGetNonExistent(t, ctx, backend)
})
t.Run("UpdateExisting", func(t *testing.T) {
testUpdateExisting(t, ctx, backend)
})
t.Run("Delete", func(t *testing.T) {
testDelete(t, ctx, backend)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
testDeleteNonExistent(t, ctx, backend)
})
t.Run("Exists", func(t *testing.T) {
testExists(t, ctx, backend)
})
t.Run("TTLExpiration", func(t *testing.T) {
testTTLExpiration(t, ctx, backend)
})
t.Run("Clear", func(t *testing.T) {
testClear(t, ctx, backend)
})
t.Run("Ping", func(t *testing.T) {
testPing(t, ctx, backend)
})
t.Run("Stats", func(t *testing.T) {
testStats(t, ctx, backend)
})
t.Run("ConcurrentAccess", func(t *testing.T) {
testConcurrentAccess(t, ctx, backend)
})
t.Run("LargeValues", func(t *testing.T) {
testLargeValues(t, ctx, backend)
})
t.Run("EmptyValues", func(t *testing.T) {
testEmptyValues(t, ctx, backend)
})
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
testSpecialCharactersInKeys(t, ctx, backend)
})
}
// testBasicSetGet verifies basic set and get operations
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "test-key-1"
value := []byte("test-value-1")
ttl := 1 * time.Minute
// Set value
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err, "Set should not return error")
// Get value
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error")
assert.True(t, exists, "Key should exist")
assert.Equal(t, value, retrieved, "Retrieved value should match")
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
}
// testGetNonExistent verifies behavior when getting non-existent keys
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-key"
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err, "Get should not return error for non-existent key")
assert.False(t, exists, "Key should not exist")
assert.Nil(t, retrieved, "Value should be nil")
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
}
// testUpdateExisting verifies updating an existing key
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
ttl := 1 * time.Minute
// Set initial value
err := backend.Set(ctx, key, value1, ttl)
require.NoError(t, err)
// Update value
err = backend.Set(ctx, key, value2, ttl)
require.NoError(t, err)
// Verify updated value
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved, "Value should be updated")
}
// testDelete verifies delete operation
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "delete-key"
value := []byte("delete-value")
// Set value
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Verify exists
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Delete
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted, "Delete should return true for existing key")
// Verify deleted
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after delete")
}
// testDeleteNonExistent verifies deleting non-existent keys
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "non-existent-delete-key"
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.False(t, deleted, "Delete should return false for non-existent key")
}
// testExists verifies the Exists operation
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "exists-key"
value := []byte("exists-value")
// Check non-existent key
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist initially")
// Set value
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check existing key
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist after Set")
}
// testTTLExpiration verifies TTL expiration behavior
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
// Set with short TTL
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key should exist immediately after Set")
// Wait for expiration
time.Sleep(200 * time.Millisecond)
// Verify expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after TTL expiration")
}
// testClear verifies Clear operation
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
// Set multiple keys
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Clear all
err := backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
for i := 0; i < 5; i++ {
key := fmt.Sprintf("clear-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Key should not exist after Clear")
}
}
// testPing verifies Ping operation
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
err := backend.Ping(ctx)
assert.NoError(t, err, "Ping should succeed on healthy backend")
}
// testStats verifies GetStats operation
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
stats := backend.GetStats()
assert.NotNil(t, stats, "Stats should not be nil")
// Stats should contain basic metrics
_, hasHits := stats["hits"]
_, hasMisses := stats["misses"]
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
}
// testConcurrentAccess verifies thread safety
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
var wg sync.WaitGroup
goroutines := 10
iterations := 20
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
}
}(i)
}
wg.Wait()
}
// testLargeValues verifies handling of large values
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "large-value-key"
value := GenerateLargeValue(1024 * 1024) // 1MB
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle large values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
}
// testEmptyValues verifies handling of empty values
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
key := "empty-value-key"
value := []byte{}
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle empty values")
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Empty value should exist")
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
}
// testSpecialCharactersInKeys verifies handling of special characters in keys
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
t.Helper()
specialKeys := []string{
"key:with:colons",
"key/with/slashes",
"key-with-dashes",
"key_with_underscores",
"key.with.dots",
"key|with|pipes",
}
for _, key := range specialKeys {
value := []byte(fmt.Sprintf("value-for-%s", key))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err, "Should handle special character in key: %s", key)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists, "Key with special characters should exist: %s", key)
assert.Equal(t, value, retrieved)
}
}
// Helper functions to setup different backend types
// These will be implemented in respective test files
func setupMemoryBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in memory_test.go
// For now, return nil to allow compilation
t.Skip("MemoryBackend implementation pending")
return nil
}
func setupRedisBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in redis_test.go
// For now, return nil to allow compilation
t.Skip("RedisBackend implementation pending")
return nil
}
func setupHybridBackend(t *testing.T) CacheBackend {
t.Helper()
// This will be implemented in hybrid_test.go
// For now, return nil to allow compilation
t.Skip("HybridBackend implementation pending")
return nil
}
+516
View File
@@ -0,0 +1,516 @@
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"container/list"
"context"
"sync"
"sync/atomic"
"time"
)
// memoryCacheItem represents an item in the memory cache
type memoryCacheItem struct {
key string
value interface{}
expiresAt time.Time
createdAt time.Time
accessedAt time.Time
accessCount int64
size int64
element *list.Element // for LRU tracking
}
// isExpired checks if the item is expired
func (item *memoryCacheItem) isExpired() bool {
if item.expiresAt.IsZero() {
return false
}
return time.Now().After(item.expiresAt)
}
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
type MemoryCacheBackend struct {
mu sync.RWMutex
items map[string]*memoryCacheItem
lruList *list.List
maxSize int64
maxMemory int64
currentSize int64
currentMemory int64
// Statistics
hits atomic.Int64
misses atomic.Int64
sets atomic.Int64
deletes atomic.Int64
evictions atomic.Int64
errors atomic.Int64
// Latency tracking
totalGetTime atomic.Int64
totalSetTime atomic.Int64
getCount atomic.Int64
setCount atomic.Int64
// Status
startTime time.Time
lastError string
lastErrorTime time.Time
cleanupTicker *time.Ticker
cleanupDone chan bool
closed atomic.Bool
// Configuration
cleanupInterval time.Duration
evictionPolicy string // "lru", "lfu", "fifo"
}
// NewMemoryCacheBackend creates a new memory cache backend
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
if maxSize <= 0 {
maxSize = 10000 // Default to 10k items
}
if maxMemory <= 0 {
maxMemory = 100 * 1024 * 1024 // Default to 100MB
}
if cleanupInterval <= 0 {
cleanupInterval = 5 * time.Minute
}
m := &MemoryCacheBackend{
items: make(map[string]*memoryCacheItem),
lruList: list.New(),
maxSize: maxSize,
maxMemory: maxMemory,
startTime: time.Now(),
cleanupInterval: cleanupInterval,
evictionPolicy: "lru",
cleanupDone: make(chan bool),
}
// Start cleanup goroutine
m.cleanupTicker = time.NewTicker(cleanupInterval)
go m.cleanupLoop()
return m
}
// cleanupLoop runs periodic cleanup of expired items
func (m *MemoryCacheBackend) cleanupLoop() {
for {
select {
case <-m.cleanupTicker.C:
m.cleanupExpired()
case <-m.cleanupDone:
return
}
}
}
// cleanupExpired removes all expired items from the cache
func (m *MemoryCacheBackend) cleanupExpired() {
m.mu.Lock()
defer m.mu.Unlock()
var keysToDelete []string
for key, item := range m.items {
if item.isExpired() {
keysToDelete = append(keysToDelete, key)
}
}
for _, key := range keysToDelete {
m.deleteItemLocked(key)
}
}
// Get retrieves a value from the cache
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalGetTime.Add(duration)
m.getCount.Add(1)
}()
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
m.misses.Add(1)
return nil, ErrCacheMiss
}
if item.isExpired() {
m.mu.Lock()
m.deleteItemLocked(key)
m.mu.Unlock()
m.misses.Add(1)
return nil, ErrCacheMiss
}
// Update access time and count
m.mu.Lock()
item.accessedAt = time.Now()
item.accessCount++
// Move to front of LRU list
if m.evictionPolicy == "lru" && item.element != nil {
m.lruList.MoveToFront(item.element)
}
m.mu.Unlock()
m.hits.Add(1)
return item.value, nil
}
// Set stores a value in the cache with optional TTL
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
start := time.Now()
defer func() {
duration := time.Since(start).Nanoseconds()
m.totalSetTime.Add(duration)
m.setCount.Add(1)
}()
// Calculate item size (simplified estimation)
itemSize := int64(len(key)) + estimateValueSize(value)
m.mu.Lock()
defer m.mu.Unlock()
// Check if we need to evict items
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
m.evictLocked()
}
// Check if key exists
if oldItem, exists := m.items[key]; exists {
m.currentMemory -= oldItem.size
if oldItem.element != nil {
m.lruList.Remove(oldItem.element)
}
} else {
m.currentSize++
}
now := time.Now()
var expiresAt time.Time
if ttl > 0 {
expiresAt = now.Add(ttl)
}
item := &memoryCacheItem{
key: key,
value: value,
expiresAt: expiresAt,
createdAt: now,
accessedAt: now,
accessCount: 0,
size: itemSize,
}
// Add to LRU list
if m.evictionPolicy == "lru" {
item.element = m.lruList.PushFront(item)
}
m.items[key] = item
m.currentMemory += itemSize
m.sets.Add(1)
return nil
}
// Delete removes a key from the cache
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.items[key]; !exists {
return nil
}
m.deleteItemLocked(key)
m.deletes.Add(1)
return nil
}
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
if item, exists := m.items[key]; exists {
m.currentMemory -= item.size
m.currentSize--
if item.element != nil {
m.lruList.Remove(item.element)
}
delete(m.items, key)
}
}
// evictLocked evicts items based on the eviction policy (must be called with lock held)
func (m *MemoryCacheBackend) evictLocked() {
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
// Evict least recently used item
element := m.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
m.deleteItemLocked(item.key)
m.evictions.Add(1)
}
}
}
// Exists checks if a key exists in the cache
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
if m.closed.Load() {
return false, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists {
return false, nil
}
return !item.isExpired(), nil
}
// Clear removes all items from the cache
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryCacheItem)
m.lruList = list.New()
m.currentSize = 0
m.currentMemory = 0
return nil
}
// Keys returns all keys matching the pattern (use "*" for all keys)
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
var keys []string
for key, item := range m.items {
if !item.isExpired() && matchPattern(pattern, key) {
keys = append(keys, key)
}
}
return keys, nil
}
// Size returns the number of items in the cache
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
defer m.mu.RUnlock()
return m.currentSize, nil
}
// TTL returns the remaining time-to-live for a key
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
if m.closed.Load() {
return 0, ErrBackendUnavailable
}
m.mu.RLock()
item, exists := m.items[key]
m.mu.RUnlock()
if !exists || item.isExpired() {
return 0, ErrCacheMiss
}
if item.expiresAt.IsZero() {
return 0, nil // No expiration
}
remaining := time.Until(item.expiresAt)
if remaining < 0 {
return 0, nil
}
return remaining, nil
}
// Expire updates the TTL for an existing key
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
m.mu.Lock()
defer m.mu.Unlock()
item, exists := m.items[key]
if !exists || item.isExpired() {
return ErrCacheMiss
}
if ttl > 0 {
item.expiresAt = time.Now().Add(ttl)
} else {
item.expiresAt = time.Time{} // Remove expiration
}
return nil
}
// GetStats returns statistics about the cache backend
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
if m.closed.Load() {
return nil, ErrBackendUnavailable
}
m.mu.RLock()
lastError := m.lastError
lastErrorTime := m.lastErrorTime
m.mu.RUnlock()
avgGetLatency := time.Duration(0)
if getCount := m.getCount.Load(); getCount > 0 {
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
}
avgSetLatency := time.Duration(0)
if setCount := m.setCount.Load(); setCount > 0 {
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
}
return &BackendStats{
Type: TypeMemory,
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Sets: m.sets.Load(),
Deletes: m.deletes.Load(),
Errors: m.errors.Load(),
Evictions: m.evictions.Load(),
CurrentSize: m.currentSize,
MaxSize: m.maxSize,
MemoryUsage: m.currentMemory,
AverageGetLatency: avgGetLatency,
AverageSetLatency: avgSetLatency,
LastError: lastError,
LastErrorTime: lastErrorTime,
Uptime: time.Since(m.startTime),
StartTime: m.startTime,
}, nil
}
// Ping checks if the backend is healthy
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
if m.closed.Load() {
return ErrBackendUnavailable
}
return nil
}
// Close closes the backend and releases resources
func (m *MemoryCacheBackend) Close() error {
if m.closed.Swap(true) {
return nil // Already closed
}
m.cleanupTicker.Stop()
close(m.cleanupDone)
m.mu.Lock()
m.items = nil
m.lruList = nil
m.mu.Unlock()
return nil
}
// IsHealthy returns true if the backend is healthy
func (m *MemoryCacheBackend) IsHealthy() bool {
return !m.closed.Load()
}
// Type returns the backend type
func (m *MemoryCacheBackend) Type() BackendType {
return TypeMemory
}
// Capabilities returns the backend capabilities
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
return &BackendCapabilities{
Distributed: false,
Persistent: false,
Eviction: true,
TTL: true,
MaxKeySize: 1024, // 1KB
MaxValueSize: 10485760, // 10MB
MaxKeys: m.maxSize,
SupportsExpire: true,
SupportsMultiGet: true,
SupportsTransaction: false,
SupportsCompression: false,
RequiresSerialize: false,
}
}
// Helper functions
// estimateValueSize estimates the size of a value in bytes
func estimateValueSize(value interface{}) int64 {
// This is a simplified estimation
// In production, you might want to use a more accurate method
switch v := value.(type) {
case string:
return int64(len(v))
case []byte:
return int64(len(v))
case int, int32, int64, uint, uint32, uint64:
return 8
case float32, float64:
return 8
case bool:
return 1
default:
// For complex types, use a default estimate
return 256
}
}
// matchPattern checks if a key matches a pattern (simplified glob matching)
func matchPattern(pattern, key string) bool {
if pattern == "*" {
return true
}
// Simplified pattern matching - in production, use a proper glob library
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
}
+501
View File
@@ -0,0 +1,501 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMemoryBackend_BasicOperations tests basic CRUD operations
func TestMemoryBackend_BasicOperations(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "test-key"
value := []byte("test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
assert.LessOrEqual(t, remainingTTL, ttl)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "delete-key"
value := []byte("delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("DeleteNonExistent", func(t *testing.T) {
deleted, err := backend.Delete(ctx, "non-existent-delete")
require.NoError(t, err)
assert.False(t, deleted)
})
t.Run("Exists", func(t *testing.T) {
key := "exists-key"
value := []byte("exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
t.Run("Clear", func(t *testing.T) {
// Add multiple items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
err := backend.Clear(ctx)
require.NoError(t, err)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(0), size)
})
}
// TestMemoryBackend_TTLExpiration tests TTL and expiration
func TestMemoryBackend_TTLExpiration(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.CleanupInterval = 50 * time.Millisecond
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "short-ttl-key"
value := []byte("short-ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Verify exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait for expiration
time.Sleep(150 * time.Millisecond)
// Should be expired
_, _, exists, err = backend.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLDecrement", func(t *testing.T) {
key := "ttl-decrement-key"
value := []byte("ttl-decrement-value")
ttl := 2 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Check TTL immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Wait a bit
time.Sleep(500 * time.Millisecond)
// Check TTL again - should be less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
})
t.Run("CleanupExpiredItems", func(t *testing.T) {
// Set multiple items with short TTL
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
err := backend.Set(ctx, key, value, 50*time.Millisecond)
require.NoError(t, err)
}
// Wait for cleanup to run
time.Sleep(200 * time.Millisecond)
// All items should be cleaned up
for i := 0; i < 5; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists, "Expired items should be cleaned up")
}
})
}
// TestMemoryBackend_LRUEviction tests LRU eviction
func TestMemoryBackend_LRUEviction(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 5
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Fill cache to max size
for i := 0; i < 5; i++ {
key := fmt.Sprintf("lru-key-%d", i)
value := []byte(fmt.Sprintf("lru-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Access first item to make it most recently used
_, _, exists, err := backend.Get(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists)
// Add a new item - should evict lru-key-1 (least recently used)
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
require.NoError(t, err)
// lru-key-0 should still exist (was accessed recently)
exists, err = backend.Exists(ctx, "lru-key-0")
require.NoError(t, err)
assert.True(t, exists, "Recently accessed item should not be evicted")
// lru-key-1 should be evicted
exists, err = backend.Exists(ctx, "lru-key-1")
require.NoError(t, err)
assert.False(t, exists, "Least recently used item should be evicted")
// Check eviction count
stats := backend.GetStats()
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have evictions")
}
// TestMemoryBackend_MemoryLimit tests memory-based eviction
func TestMemoryBackend_MemoryLimit(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxSize = 100
config.MaxMemoryBytes = 1024 // 1KB limit
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add items until memory limit is reached
largeValue := make([]byte, 512) // 512 bytes each
for i := 0; i < 5; i++ {
key := fmt.Sprintf("mem-key-%d", i)
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
}
stats := backend.GetStats()
memory := stats["memory"].(int64)
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
evictions := stats["evictions"].(int64)
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
}
// TestMemoryBackend_ConcurrentAccess tests thread safety
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
// Concurrent writes
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
// Read back
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
// Random deletes
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
// Verify stats are consistent
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
}
// TestMemoryBackend_UpdateExisting tests updating existing keys
func TestMemoryBackend_UpdateExisting(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
// Size should not increase (same key)
stats := backend.GetStats()
size := stats["size"].(int64)
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
}
// TestMemoryBackend_Stats tests statistics tracking
func TestMemoryBackend_Stats(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add items and track hits/misses
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
// Hit
backend.Get(ctx, "key1")
// Miss
backend.Get(ctx, "non-existent")
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestMemoryBackend_EmptyValues tests handling of empty values
func TestMemoryBackend_EmptyValues(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestMemoryBackend_LargeValues tests handling of large values
func TestMemoryBackend_LargeValues(t *testing.T) {
t.Parallel()
config := DefaultConfig()
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
backend, err := NewMemoryBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestMemoryBackend_Close tests proper cleanup on close
func TestMemoryBackend_Close(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
ctx := context.Background()
// Add some items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations after close should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
_, _, _, err = backend.Get(ctx, "close-key-0")
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Closing again should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestMemoryBackend_Ping tests ping operation
func TestMemoryBackend_Ping(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
func TestMemoryBackend_ValueIsolation(t *testing.T) {
t.Parallel()
backend, err := NewMemoryBackend(DefaultConfig())
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "isolation-key"
originalValue := []byte("original-value")
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
require.NoError(t, err)
// Get value and modify it
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Modify retrieved value
if len(retrieved) > 0 {
retrieved[0] = 'X'
}
// Get again - should be unchanged
retrieved2, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
}
+153
View File
@@ -0,0 +1,153 @@
package backends
import (
"context"
"time"
)
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
type MemoryBackend struct {
*MemoryCacheBackend
}
// NewMemoryBackend creates a new memory backend from a config
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
maxSize := int64(config.MaxSize)
if maxSize <= 0 {
maxSize = 1000
}
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
return &MemoryBackend{
MemoryCacheBackend: cacheBackend,
}, nil
}
// Set stores a value in the cache with the specified TTL
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
if err == ErrBackendUnavailable {
return ErrBackendClosed
}
return err
}
// Get retrieves a value from the cache
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
val, err := m.MemoryCacheBackend.Get(ctx, key)
if err != nil {
if err == ErrCacheMiss {
return nil, 0, false, nil
}
if err == ErrBackendUnavailable {
return nil, 0, false, ErrBackendClosed
}
return nil, 0, false, err
}
// Get the item directly to check TTL
m.MemoryCacheBackend.mu.RLock()
item, exists := m.MemoryCacheBackend.items[key]
m.MemoryCacheBackend.mu.RUnlock()
if !exists {
return nil, 0, false, nil
}
var ttl time.Duration
if !item.expiresAt.IsZero() {
ttl = time.Until(item.expiresAt)
if ttl < 0 {
ttl = 0
}
}
// Convert interface{} to []byte
var valueBytes []byte
if val != nil {
if bytes, ok := val.([]byte); ok {
valueBytes = bytes
} else {
// If it's not already []byte, we might need to handle other types
// For now, we'll just return an error
return nil, 0, false, ErrInvalidValue
}
}
return valueBytes, ttl, true, nil
}
// Delete removes a key from the cache
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
// Check if key exists first
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
if err != nil {
return false, err
}
if !exists {
return false, nil
}
err = m.MemoryCacheBackend.Delete(ctx, key)
if err != nil {
return false, err
}
return true, nil
}
// Exists checks if a key exists in the cache
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
return m.MemoryCacheBackend.Exists(ctx, key)
}
// Clear removes all keys from the cache
func (m *MemoryBackend) Clear(ctx context.Context) error {
return m.MemoryCacheBackend.Clear(ctx)
}
// GetStats returns cache statistics
func (m *MemoryBackend) GetStats() map[string]interface{} {
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
if err != nil {
return map[string]interface{}{
"error": err.Error(),
}
}
// Convert BackendStats to map
hitRate := float64(0)
total := stats.Hits + stats.Misses
if total > 0 {
hitRate = float64(stats.Hits) / float64(total)
}
return map[string]interface{}{
"type": stats.Type,
"hits": stats.Hits,
"misses": stats.Misses,
"sets": stats.Sets,
"deletes": stats.Deletes,
"errors": stats.Errors,
"evictions": stats.Evictions,
"size": stats.CurrentSize,
"max_size": stats.MaxSize,
"memory": stats.MemoryUsage,
"hit_rate": hitRate,
"uptime": stats.Uptime,
"start_time": stats.StartTime,
}
}
// Close shuts down the cache backend and releases resources
func (m *MemoryBackend) Close() error {
return m.MemoryCacheBackend.Close()
}
// Ping checks if the backend is healthy and responsive
func (m *MemoryBackend) Ping(ctx context.Context) error {
return m.MemoryCacheBackend.Ping(ctx)
}
// Ensure MemoryBackend implements CacheBackend
var _ CacheBackend = (*MemoryBackend)(nil)
+277
View File
@@ -0,0 +1,277 @@
package backends
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
)
// RedisBackend implements a Redis-based cache backend
type RedisBackend struct {
client *redis.Client
config *Config
// Metrics
hits int64
misses int64
// Lifecycle
closed atomic.Bool
}
// NewRedisBackend creates a new Redis cache backend
func NewRedisBackend(config *Config) (*RedisBackend, error) {
if config == nil {
return nil, fmt.Errorf("config is required")
}
if config.RedisAddr == "" {
return nil, fmt.Errorf("redis address is required")
}
client := redis.NewClient(&redis.Options{
Addr: config.RedisAddr,
Password: config.RedisPassword,
DB: config.RedisDB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
client.Close()
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
backend := &RedisBackend{
client: client,
config: config,
}
return backend, nil
}
// Set stores a value with TTL
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
return r.client.Set(ctx, prefixedKey, value, ttl).Err()
}
// Get retrieves a value
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if r.closed.Load() {
return nil, 0, false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
// Get value
value, err := r.client.Get(ctx, prefixedKey).Bytes()
if err == redis.Nil {
atomic.AddInt64(&r.misses, 1)
return nil, 0, false, nil
}
if err != nil {
atomic.AddInt64(&r.misses, 1)
return nil, 0, false, fmt.Errorf("failed to get key: %w", err)
}
// Get TTL
ttl, err := r.client.TTL(ctx, prefixedKey).Result()
if err != nil {
// Value exists but couldn't get TTL
atomic.AddInt64(&r.hits, 1)
return value, 0, true, nil
}
atomic.AddInt64(&r.hits, 1)
return value, ttl, true, nil
}
// Delete removes a key
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
result, err := r.client.Del(ctx, prefixedKey).Result()
if err != nil {
return false, fmt.Errorf("failed to delete key: %w", err)
}
return result > 0, nil
}
// Exists checks if a key exists
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
if r.closed.Load() {
return false, ErrBackendClosed
}
prefixedKey := r.prefixKey(key)
result, err := r.client.Exists(ctx, prefixedKey).Result()
if err != nil {
return false, fmt.Errorf("failed to check existence: %w", err)
}
return result > 0, nil
}
// Clear removes all keys with the prefix
func (r *RedisBackend) Clear(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
// Use SCAN to find all keys with prefix
pattern := r.config.RedisPrefix + "*"
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
var keys []string
for iter.Next(ctx) {
keys = append(keys, iter.Val())
}
if err := iter.Err(); err != nil {
return fmt.Errorf("failed to scan keys: %w", err)
}
if len(keys) == 0 {
return nil
}
// Delete in batches to avoid blocking Redis
batchSize := 100
for i := 0; i < len(keys); i += batchSize {
end := i + batchSize
if end > len(keys) {
end = len(keys)
}
if err := r.client.Del(ctx, keys[i:end]...).Err(); err != nil {
return fmt.Errorf("failed to delete keys: %w", err)
}
}
return nil
}
// GetStats returns statistics
func (r *RedisBackend) GetStats() map[string]interface{} {
hits := atomic.LoadInt64(&r.hits)
misses := atomic.LoadInt64(&r.misses)
total := hits + misses
hitRate := 0.0
if total > 0 {
hitRate = float64(hits) / float64(total)
}
stats := map[string]interface{}{
"type": "redis",
"hits": hits,
"misses": misses,
"hit_rate": hitRate,
}
// Try to get Redis info (non-critical)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
if info, err := r.client.Info(ctx, "memory").Result(); err == nil {
stats["redis_info"] = info
}
return stats
}
// Ping checks Redis health
func (r *RedisBackend) Ping(ctx context.Context) error {
if r.closed.Load() {
return ErrBackendClosed
}
return r.client.Ping(ctx).Err()
}
// Close shuts down the backend
func (r *RedisBackend) Close() error {
if !r.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
return r.client.Close()
}
// prefixKey adds the configured prefix to a key
func (r *RedisBackend) prefixKey(key string) string {
if r.config.RedisPrefix == "" {
return key
}
return r.config.RedisPrefix + key
}
// Pipeline operations for batch operations (future enhancement)
// SetMany stores multiple key-value pairs in a pipeline
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
if r.closed.Load() {
return ErrBackendClosed
}
pipe := r.client.Pipeline()
for key, value := range items {
prefixedKey := r.prefixKey(key)
pipe.Set(ctx, prefixedKey, value, ttl)
}
_, err := pipe.Exec(ctx)
return err
}
// GetMany retrieves multiple values in a pipeline
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
if r.closed.Load() {
return nil, ErrBackendClosed
}
if len(keys) == 0 {
return make(map[string][]byte), nil
}
pipe := r.client.Pipeline()
cmds := make(map[string]*redis.StringCmd, len(keys))
for _, key := range keys {
prefixedKey := r.prefixKey(key)
cmds[key] = pipe.Get(ctx, prefixedKey)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
results := make(map[string][]byte)
for key, cmd := range cmds {
value, err := cmd.Bytes()
if err == nil {
results[key] = value
atomic.AddInt64(&r.hits, 1)
} else if err != redis.Nil {
atomic.AddInt64(&r.misses, 1)
}
}
return results, nil
}
+545
View File
@@ -0,0 +1,545 @@
package backends
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisBackend_BasicOperations tests basic Redis operations
func TestRedisBackend_BasicOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetAndGet", func(t *testing.T) {
key := "redis-test-key"
value := []byte("redis-test-value")
ttl := 1 * time.Minute
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
assert.Greater(t, remainingTTL, 50*time.Second)
})
t.Run("GetNonExistent", func(t *testing.T) {
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Delete", func(t *testing.T) {
key := "redis-delete-key"
value := []byte("redis-delete-value")
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
deleted, err := backend.Delete(ctx, key)
require.NoError(t, err)
assert.True(t, deleted)
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("Exists", func(t *testing.T) {
key := "redis-exists-key"
value := []byte("redis-exists-value")
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
})
}
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
func TestRedisBackend_KeyPrefixing(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "test:prefix:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "my-key"
value := []byte("my-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check that key is stored with prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, "test:prefix:my-key", keys[0])
// Get should work without prefix
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
}
// TestRedisBackend_TTLExpiration tests TTL handling
func TestRedisBackend_TTLExpiration(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("ShortTTL", func(t *testing.T) {
key := "ttl-key"
value := []byte("ttl-value")
shortTTL := 100 * time.Millisecond
err := backend.Set(ctx, key, value, shortTTL)
require.NoError(t, err)
// Exists immediately
exists, err := backend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward time in miniredis
mr.FastForward(150 * time.Millisecond)
// Should be expired
exists, err = backend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
t.Run("TTLRemaining", func(t *testing.T) {
key := "ttl-remaining-key"
value := []byte("ttl-remaining-value")
ttl := 10 * time.Second
err := backend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Get immediately
_, ttl1, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
// Fast forward 2 seconds
mr.FastForward(2 * time.Second)
// Check TTL is less
_, ttl2, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1)
})
}
// TestRedisBackend_Clear tests clearing all keys
func TestRedisBackend_Clear(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "clear-test:"
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Add multiple keys
for i := 0; i < 10; i++ {
key := fmt.Sprintf("clear-key-%d", i)
value := []byte(fmt.Sprintf("clear-value-%d", i))
err := backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Verify keys exist
keys := mr.CheckKeys()
assert.Len(t, keys, 10)
// Clear all
err = backend.Clear(ctx)
require.NoError(t, err)
// Verify all keys are gone
keys = mr.CheckKeys()
assert.Len(t, keys, 0)
}
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
func TestRedisBackend_ConnectionFailure(t *testing.T) {
t.Parallel()
// Try to connect to non-existent Redis
config := DefaultRedisConfig("localhost:9999")
_, err := NewRedisBackend(config)
assert.Error(t, err, "Should fail to connect to non-existent Redis")
}
// TestRedisBackend_RedisErrors tests handling of Redis errors
func TestRedisBackend_RedisErrors(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Simulate Redis error
mr.SetError("simulated error")
// Operations should fail
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
assert.Error(t, err)
// Clear error
mr.ClearError()
// Operations should work again
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
assert.NoError(t, err)
}
// TestRedisBackend_ConcurrentAccess tests thread safety
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
err := backend.Set(ctx, key, value, 1*time.Minute)
assert.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
assert.NoError(t, err)
if exists {
assert.Equal(t, value, retrieved)
}
if j%5 == 0 {
backend.Delete(ctx, key)
}
}
}(i)
}
wg.Wait()
stats := backend.GetStats()
hits := stats["hits"].(int64)
misses := stats["misses"].(int64)
assert.Greater(t, hits+misses, int64(0))
}
// TestRedisBackend_Stats tests statistics tracking
func TestRedisBackend_Stats(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
// Initial stats
stats := backend.GetStats()
assert.Equal(t, int64(0), stats["hits"].(int64))
assert.Equal(t, int64(0), stats["misses"].(int64))
// Add and access items
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
backend.Get(ctx, "key1") // Hit
backend.Get(ctx, "non-existent") // Miss
stats = backend.GetStats()
assert.Equal(t, int64(1), stats["hits"].(int64))
assert.Equal(t, int64(1), stats["misses"].(int64))
hitRate := stats["hit_rate"].(float64)
assert.InDelta(t, 0.5, hitRate, 0.01)
}
// TestRedisBackend_Ping tests health check
func TestRedisBackend_Ping(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
err = backend.Ping(ctx)
assert.NoError(t, err)
// Close and ping should fail
backend.Close()
err = backend.Ping(ctx)
assert.Error(t, err)
}
// TestRedisBackend_Close tests proper cleanup
func TestRedisBackend_Close(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
ctx := context.Background()
// Add items
for i := 0; i < 10; i++ {
key := fmt.Sprintf("close-key-%d", i)
value := []byte(fmt.Sprintf("close-value-%d", i))
backend.Set(ctx, key, value, 1*time.Minute)
}
// Close
err = backend.Close()
require.NoError(t, err)
// Operations should fail
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
assert.Error(t, err)
assert.Equal(t, ErrBackendClosed, err)
// Double close should be safe
err = backend.Close()
assert.NoError(t, err)
}
// TestRedisBackend_UpdateExisting tests updating existing keys
func TestRedisBackend_UpdateExisting(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "update-key"
value1 := []byte("original-value")
value2 := []byte("updated-value")
// Set original
err = backend.Set(ctx, key, value1, 1*time.Minute)
require.NoError(t, err)
// Update
err = backend.Set(ctx, key, value2, 2*time.Minute)
require.NoError(t, err)
// Verify updated
retrieved, ttl, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value2, retrieved)
assert.Greater(t, ttl, 1*time.Minute)
}
// TestRedisBackend_LargeValues tests handling of large values
func TestRedisBackend_LargeValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "large-key"
largeValue := make([]byte, 1024*1024) // 1MB
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, len(largeValue), len(retrieved))
}
// TestRedisBackend_EmptyValues tests handling of empty values
func TestRedisBackend_EmptyValues(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "empty-key"
emptyValue := []byte{}
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
require.NoError(t, err)
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, 0, len(retrieved))
}
// TestRedisBackend_PipelineOperations tests batch operations
func TestRedisBackend_PipelineOperations(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
t.Run("SetMany", func(t *testing.T) {
items := make(map[string][]byte)
for i := 0; i < 10; i++ {
key := fmt.Sprintf("batch-key-%d", i)
value := []byte(fmt.Sprintf("batch-value-%d", i))
items[key] = value
}
err := backend.SetMany(ctx, items, 1*time.Minute)
require.NoError(t, err)
// Verify all items were set
for key, expectedValue := range items {
retrieved, _, exists, err := backend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, expectedValue, retrieved)
}
})
t.Run("GetMany", func(t *testing.T) {
// Set test data
testData := GenerateTestData(5)
for key, value := range testData {
backend.Set(ctx, key, value, 1*time.Minute)
}
// Get all keys
keys := make([]string, 0, len(testData))
for key := range testData {
keys = append(keys, key)
}
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, len(testData))
for key, expectedValue := range testData {
retrievedValue, exists := results[key]
assert.True(t, exists)
assert.Equal(t, expectedValue, retrievedValue)
}
})
t.Run("GetManyWithNonExistent", func(t *testing.T) {
keys := []string{"exists-1", "non-existent", "exists-2"}
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
results, err := backend.GetMany(ctx, keys)
require.NoError(t, err)
assert.Len(t, results, 2) // Only existing keys
assert.Equal(t, []byte("value-1"), results["exists-1"])
assert.Equal(t, []byte("value-2"), results["exists-2"])
_, exists := results["non-existent"]
assert.False(t, exists)
})
}
// TestRedisBackend_NoPrefix tests operation without prefix
func TestRedisBackend_NoPrefix(t *testing.T) {
t.Parallel()
mr := NewMiniredisServer(t)
config := DefaultRedisConfig(mr.GetAddr())
config.RedisPrefix = "" // No prefix
backend, err := NewRedisBackend(config)
require.NoError(t, err)
defer backend.Close()
ctx := context.Background()
key := "no-prefix-key"
value := []byte("no-prefix-value")
err = backend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Check key is stored without prefix
keys := mr.CheckKeys()
require.Len(t, keys, 1)
assert.Equal(t, key, keys[0])
}
+184
View File
@@ -0,0 +1,184 @@
package backends
import (
"context"
"fmt"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
// TestLogger implements a simple logger for tests
type TestLogger struct {
t *testing.T
}
func NewTestLogger(t *testing.T) *TestLogger {
return &TestLogger{t: t}
}
func (l *TestLogger) Debug(format string, args ...interface{}) {
l.t.Logf("[DEBUG] "+format, args...)
}
func (l *TestLogger) Info(format string, args ...interface{}) {
l.t.Logf("[INFO] "+format, args...)
}
func (l *TestLogger) Error(format string, args ...interface{}) {
l.t.Logf("[ERROR] "+format, args...)
}
func (l *TestLogger) Debugf(format string, args ...interface{}) {
l.Debug(format, args...)
}
func (l *TestLogger) Infof(format string, args ...interface{}) {
l.Info(format, args...)
}
func (l *TestLogger) Errorf(format string, args ...interface{}) {
l.Error(format, args...)
}
// MiniredisServer manages a miniredis instance for testing
type MiniredisServer struct {
server *miniredis.Miniredis
client *redis.Client
}
// NewMiniredisServer creates a new miniredis server for testing
func NewMiniredisServer(t *testing.T) *MiniredisServer {
t.Helper()
mr, err := miniredis.Run()
require.NoError(t, err, "failed to start miniredis")
client := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
// Verify connection
ctx := context.Background()
err = client.Ping(ctx).Err()
require.NoError(t, err, "failed to ping miniredis")
t.Cleanup(func() {
client.Close()
mr.Close()
})
return &MiniredisServer{
server: mr,
client: client,
}
}
// GetAddr returns the address of the miniredis server
func (m *MiniredisServer) GetAddr() string {
return m.server.Addr()
}
// GetClient returns the Redis client
func (m *MiniredisServer) GetClient() *redis.Client {
return m.client
}
// FastForward advances the miniredis server's time
func (m *MiniredisServer) FastForward(d time.Duration) {
m.server.FastForward(d)
}
// FlushAll removes all keys from the database
func (m *MiniredisServer) FlushAll() {
m.server.FlushAll()
}
// SetError simulates a Redis error
func (m *MiniredisServer) SetError(err string) {
m.server.SetError(err)
}
// ClearError clears any simulated errors
func (m *MiniredisServer) ClearError() {
m.server.SetError("")
}
// CheckKeys verifies that specific keys exist in Redis
func (m *MiniredisServer) CheckKeys() []string {
return m.server.Keys()
}
// TestConfig provides default test configuration
type TestConfig struct {
MaxSize int
DefaultTTL time.Duration
CleanupInterval time.Duration
EnableMetrics bool
}
// DefaultTestConfig returns a standard test configuration
func DefaultTestConfig() *TestConfig {
return &TestConfig{
MaxSize: 100,
DefaultTTL: 5 * time.Minute,
CleanupInterval: 1 * time.Second,
EnableMetrics: true,
}
}
// GenerateTestData creates test cache data
func GenerateTestData(count int) map[string][]byte {
data := make(map[string][]byte, count)
for i := 0; i < count; i++ {
key := fmt.Sprintf("test-key-%d", i)
value := []byte(fmt.Sprintf("test-value-%d", i))
data[key] = value
}
return data
}
// GenerateLargeValue creates a large test value
func GenerateLargeValue(sizeBytes int) []byte {
return make([]byte, sizeBytes)
}
// AssertCacheStats is a helper to verify cache statistics
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
t.Helper()
hits, ok := stats["hits"].(int64)
require.True(t, ok, "hits should be int64")
require.Equal(t, expectedHits, hits, "unexpected hit count")
misses, ok := stats["misses"].(int64)
require.True(t, ok, "misses should be int64")
require.Equal(t, expectedMisses, misses, "unexpected miss count")
}
// WaitForCondition waits for a condition to be true or times out
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if condition() {
return
}
time.Sleep(checkInterval)
}
t.Fatal("timeout waiting for condition")
}
// AssertEventuallyExpires verifies that a key eventually expires
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
t.Helper()
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
_, _, exists, err := backend.Get(ctx, key)
return err == nil && !exists
})
}
+329
View File
@@ -0,0 +1,329 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
)
// Common errors
var (
// ErrCircuitOpen is returned when the circuit breaker is open
ErrCircuitOpen = errors.New("circuit breaker is open")
// ErrTooManyRequests is returned when too many requests are made in half-open state
ErrTooManyRequests = errors.New("too many requests in half-open state")
)
// State represents the state of the circuit breaker
type State int32
const (
// StateClosed allows all operations to pass through
StateClosed State = iota
// StateOpen blocks all operations
StateOpen
// StateHalfOpen allows a limited number of operations to test recovery
StateHalfOpen
)
// String returns the string representation of the state
func (s State) String() string {
switch s {
case StateClosed:
return "closed"
case StateOpen:
return "open"
case StateHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreakerConfig holds configuration for the circuit breaker
type CircuitBreakerConfig struct {
// MaxFailures is the number of consecutive failures before opening the circuit
MaxFailures int
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
FailureThreshold float64
// Timeout is how long the circuit stays open before trying half-open
Timeout time.Duration
// HalfOpenMaxRequests is the number of requests allowed in half-open state
HalfOpenMaxRequests int
// ResetTimeout is how long to wait before resetting counters in closed state
ResetTimeout time.Duration
// OnStateChange is called when the circuit breaker changes state
OnStateChange func(from, to State)
}
// DefaultCircuitBreakerConfig returns default configuration
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
return &CircuitBreakerConfig{
MaxFailures: 5,
FailureThreshold: 0.6,
Timeout: 30 * time.Second,
HalfOpenMaxRequests: 3,
ResetTimeout: 60 * time.Second,
}
}
// CircuitBreaker implements the circuit breaker pattern
type CircuitBreaker struct {
config *CircuitBreakerConfig
// State management
state atomic.Int32
lastStateChange time.Time
stateMu sync.RWMutex
// Failure tracking
consecutiveFailures atomic.Int32
totalRequests atomic.Int64
totalFailures atomic.Int64
halfOpenRequests atomic.Int32
// Timing
lastFailureTime time.Time
lastSuccessTime time.Time
nextRetryTime time.Time
timeMu sync.RWMutex
// Metrics
stateTransitions atomic.Int64
rejectedRequests atomic.Int64
}
// NewCircuitBreaker creates a new circuit breaker
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
if config == nil {
config = DefaultCircuitBreakerConfig()
}
return &CircuitBreaker{
config: config,
lastStateChange: time.Now(),
}
}
// Execute runs a function through the circuit breaker
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
if !cb.AllowRequest() {
cb.rejectedRequests.Add(1)
return ErrCircuitOpen
}
cb.totalRequests.Add(1)
err := fn()
if err != nil {
cb.RecordFailure()
} else {
cb.RecordSuccess()
}
return err
}
// AllowRequest checks if a request is allowed to proceed
func (cb *CircuitBreaker) AllowRequest() bool {
state := cb.GetState()
switch state {
case StateClosed:
return true
case StateOpen:
// Check if timeout has passed and we should try half-open
cb.timeMu.RLock()
shouldRetry := time.Now().After(cb.nextRetryTime)
cb.timeMu.RUnlock()
if shouldRetry {
cb.setState(StateHalfOpen)
return true
}
return false
case StateHalfOpen:
// Allow limited requests in half-open state
current := cb.halfOpenRequests.Add(1)
return current <= int32(cb.config.HalfOpenMaxRequests)
default:
return false
}
}
// RecordSuccess records a successful operation
func (cb *CircuitBreaker) RecordSuccess() {
cb.timeMu.Lock()
cb.lastSuccessTime = time.Now()
cb.timeMu.Unlock()
state := cb.GetState()
switch state {
case StateClosed:
// Reset consecutive failures
cb.consecutiveFailures.Store(0)
case StateHalfOpen:
// If we've had enough successful requests, close the circuit
successfulRequests := cb.halfOpenRequests.Load()
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
cb.halfOpenRequests.Store(0)
}
}
}
// RecordFailure records a failed operation
func (cb *CircuitBreaker) RecordFailure() {
cb.totalFailures.Add(1)
failures := cb.consecutiveFailures.Add(1)
cb.timeMu.Lock()
cb.lastFailureTime = time.Now()
cb.timeMu.Unlock()
state := cb.GetState()
switch state {
case StateClosed:
// Check if we should open the circuit
if failures >= int32(cb.config.MaxFailures) {
cb.openCircuit()
} else if cb.config.FailureThreshold > 0 {
// Check failure rate
total := cb.totalRequests.Load()
failureCount := cb.totalFailures.Load()
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
cb.openCircuit()
}
}
case StateHalfOpen:
// Any failure in half-open state reopens the circuit
cb.openCircuit()
}
}
// openCircuit transitions to open state
func (cb *CircuitBreaker) openCircuit() {
cb.setState(StateOpen)
cb.halfOpenRequests.Store(0)
cb.timeMu.Lock()
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
cb.timeMu.Unlock()
}
// GetState returns the current state
func (cb *CircuitBreaker) GetState() State {
return State(cb.state.Load())
}
// setState changes the circuit breaker state
func (cb *CircuitBreaker) setState(newState State) {
oldState := State(cb.state.Swap(int32(newState)))
if oldState != newState {
cb.stateTransitions.Add(1)
cb.stateMu.Lock()
cb.lastStateChange = time.Now()
cb.stateMu.Unlock()
if cb.config.OnStateChange != nil {
cb.config.OnStateChange(oldState, newState)
}
}
}
// Reset resets the circuit breaker to closed state
func (cb *CircuitBreaker) Reset() {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
cb.totalRequests.Store(0)
cb.totalFailures.Store(0)
cb.halfOpenRequests.Store(0)
cb.rejectedRequests.Store(0)
cb.stateTransitions.Store(0)
now := time.Now()
cb.timeMu.Lock()
cb.lastFailureTime = now
cb.lastSuccessTime = now
cb.nextRetryTime = now
cb.timeMu.Unlock()
cb.stateMu.Lock()
cb.lastStateChange = now
cb.stateMu.Unlock()
}
// Stats returns circuit breaker statistics
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
cb.timeMu.RLock()
lastFailure := cb.lastFailureTime
lastSuccess := cb.lastSuccessTime
nextRetry := cb.nextRetryTime
cb.timeMu.RUnlock()
cb.stateMu.RLock()
lastChange := cb.lastStateChange
cb.stateMu.RUnlock()
totalReq := cb.totalRequests.Load()
totalFail := cb.totalFailures.Load()
successRate := float64(0)
if totalReq > 0 {
successRate = float64(totalReq-totalFail) / float64(totalReq)
}
return CircuitBreakerStats{
State: cb.GetState(),
ConsecutiveFailures: cb.consecutiveFailures.Load(),
TotalRequests: totalReq,
TotalFailures: totalFail,
SuccessRate: successRate,
RejectedRequests: cb.rejectedRequests.Load(),
StateTransitions: cb.stateTransitions.Load(),
LastFailureTime: lastFailure,
LastSuccessTime: lastSuccess,
LastStateChange: lastChange,
NextRetryTime: nextRetry,
}
}
// CircuitBreakerStats holds statistics for the circuit breaker
type CircuitBreakerStats struct {
State State
ConsecutiveFailures int32
TotalRequests int64
TotalFailures int64
SuccessRate float64
RejectedRequests int64
StateTransitions int64
LastFailureTime time.Time
LastSuccessTime time.Time
LastStateChange time.Time
NextRetryTime time.Time
}
// IsHealthy returns true if the circuit breaker is in a healthy state
func (cb *CircuitBreaker) IsHealthy() bool {
return cb.GetState() != StateOpen
}
+141
View File
@@ -0,0 +1,141 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
)
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
type CircuitBreakerBackend struct {
backend backends.CacheBackend
cb *CircuitBreaker
}
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
if config == nil {
config = DefaultCircuitBreakerConfig()
}
return &CircuitBreakerBackend{
backend: b,
cb: NewCircuitBreaker(config),
}
}
// Set stores a value with circuit breaker protection
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Set(ctx, key, value, ttl)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// Get retrieves a value with circuit breaker protection
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
if !c.cb.AllowRequest() {
return nil, 0, false, backends.ErrCircuitOpen
}
value, ttl, exists, err := c.backend.Get(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return value, ttl, exists, err
}
// Delete removes a key with circuit breaker protection
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
if !c.cb.AllowRequest() {
return false, backends.ErrCircuitOpen
}
deleted, err := c.backend.Delete(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return deleted, err
}
// Exists checks if a key exists with circuit breaker protection
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
if !c.cb.AllowRequest() {
return false, backends.ErrCircuitOpen
}
exists, err := c.backend.Exists(ctx, key)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return exists, err
}
// Clear removes all keys with circuit breaker protection
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Clear(ctx)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// GetStats returns statistics including circuit breaker state
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
stats := c.backend.GetStats()
if stats == nil {
stats = make(map[string]interface{})
}
cbStats := c.cb.Stats()
stats["circuit_breaker"] = map[string]interface{}{
"state": cbStats.State.String(),
"consecutive_failures": cbStats.ConsecutiveFailures,
"total_requests": cbStats.TotalRequests,
"total_failures": cbStats.TotalFailures,
"success_rate": cbStats.SuccessRate,
}
return stats
}
// Ping checks backend health with circuit breaker protection
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
if !c.cb.AllowRequest() {
return backends.ErrCircuitOpen
}
err := c.backend.Ping(ctx)
if err == nil {
c.cb.RecordSuccess()
} else {
c.cb.RecordFailure()
}
return err
}
// Close shuts down the backend
func (c *CircuitBreakerBackend) Close() error {
return c.backend.Close()
}
+553
View File
@@ -0,0 +1,553 @@
package resilience
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestCircuitBreaker_StateTransitions tests state machine transitions
func TestCircuitBreaker_StateTransitions(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
t.Run("Initial state is closed", func(t *testing.T) {
assert.Equal(t, StateClosed, cb.GetState())
})
t.Run("Closed to Open after max failures", func(t *testing.T) {
cb.Reset()
// Simulate failures
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
})
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
// Open the circuit
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Wait for timeout
time.Sleep(150 * time.Millisecond)
// Should allow request and transition to half-open
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
assert.Equal(t, StateHalfOpen, cb.GetState())
})
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
// Open circuit then wait for half-open
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
time.Sleep(150 * time.Millisecond)
// First request transitions to half-open and succeeds
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
// Should be in half-open after first request
state := cb.GetState()
assert.True(t, state == StateHalfOpen || state == StateClosed,
"After first successful request, should be half-open or potentially closed")
if state == StateHalfOpen {
// Need more successful requests to close
// The exact number depends on implementation but should be within HalfOpenMaxRequests
for i := 0; i < config.HalfOpenMaxRequests; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
// After multiple successful requests, should eventually close
finalState := cb.GetState()
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
"After successful requests, circuit should transition towards closed")
}
})
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
// Open circuit then wait for half-open
cb.Reset()
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
time.Sleep(150 * time.Millisecond)
// First call transitions to half-open, second failure reopens
cb.Execute(ctx, func() error {
return errors.New("test error")
})
assert.Equal(t, StateOpen, cb.GetState())
})
}
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Trigger failures to open circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Requests should be blocked
err := cb.Execute(ctx, func() error {
t.Fatal("Should not execute function when circuit is open")
return nil
})
assert.Error(t, err)
assert.Equal(t, ErrCircuitOpen, err)
}
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open circuit then wait for half-open
for i := 0; i < 3; i++ {
cb.Execute(ctx, func() error {
return errors.New("test error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
time.Sleep(150 * time.Millisecond)
// After timeout, circuit should allow transition to half-open
// Execute HalfOpenMaxRequests successful requests
successCount := 0
for i := 0; i < config.HalfOpenMaxRequests; i++ {
err := cb.Execute(ctx, func() error {
successCount++
return nil
})
// Should allow up to HalfOpenMaxRequests
assert.NoError(t, err)
}
// Verify we executed the expected number
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
// After successful requests, circuit behavior depends on implementation
// It could close (allowing more requests) or stay half-open (blocking)
// The important thing is that we allowed exactly HalfOpenMaxRequests
}
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 3,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Have some failures (but less than max)
cb.Execute(ctx, func() error {
return errors.New("error")
})
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateClosed, cb.GetState())
stats := cb.Stats()
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
// One success should reset failures
cb.Execute(ctx, func() error {
return nil
})
assert.Equal(t, StateClosed, cb.GetState())
stats = cb.Stats()
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
}
// TestCircuitBreaker_ConcurrentAccess tests thread safety
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 10,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 5,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
var wg sync.WaitGroup
goroutines := 20
iterations := 50
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
// Mix of successes and failures
cb.Execute(ctx, func() error {
if (id+j)%3 == 0 {
return errors.New("test error")
}
return nil
})
// Random state checks
_ = cb.GetState()
_ = cb.Stats()
}
}(i)
}
wg.Wait()
// Should complete without panics
stats := cb.Stats()
assert.NotNil(t, stats)
}
// TestCircuitBreaker_Stats tests statistics tracking
func TestCircuitBreaker_Stats(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 5,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 2,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Execute some requests
cb.Execute(ctx, func() error { return nil }) // Success
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
stats := cb.Stats()
assert.Equal(t, StateClosed, stats.State)
assert.Equal(t, int64(3), stats.TotalRequests)
assert.Equal(t, int64(2), stats.TotalFailures)
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
}
// TestCircuitBreaker_Reset tests circuit reset
func TestCircuitBreaker_Reset(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open the circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
// Reset
cb.Reset()
assert.Equal(t, StateClosed, cb.GetState())
stats := cb.Stats()
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
assert.Equal(t, int64(0), stats.TotalRequests)
assert.Equal(t, int64(0), stats.TotalFailures)
}
// TestCircuitBreaker_StateChangeCallback tests state change notifications
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
t.Parallel()
var transitions []string
var mu sync.Mutex
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 50 * time.Millisecond,
HalfOpenMaxRequests: 1,
OnStateChange: func(from, to State) {
mu.Lock()
defer mu.Unlock()
transitions = append(transitions, from.String()+"->"+to.String())
},
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Trigger state transitions
// Closed -> Open
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
// Should be open now
assert.Equal(t, StateOpen, cb.GetState())
// Wait for timeout to allow half-open transition
time.Sleep(100 * time.Millisecond)
// Open -> HalfOpen on first request after timeout
err := cb.Execute(ctx, func() error {
return nil
})
assert.NoError(t, err)
// Execute more successful requests to trigger HalfOpen -> Closed
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
mu.Lock()
defer mu.Unlock()
assert.Contains(t, transitions, "closed->open")
assert.Contains(t, transitions, "open->half-open")
}
// TestCircuitBreaker_IsHealthy tests health check
func TestCircuitBreaker_IsHealthy(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 2,
Timeout: 100 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Initially healthy
assert.True(t, cb.IsHealthy())
// Open circuit
for i := 0; i < 2; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
// Wait for timeout and allow successful request
time.Sleep(150 * time.Millisecond)
cb.Execute(ctx, func() error {
return nil
})
// Should be healthy after recovery
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
}
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
func TestCircuitBreaker_RapidFailures(t *testing.T) {
t.Parallel()
config := &CircuitBreakerConfig{
MaxFailures: 5,
Timeout: 200 * time.Millisecond,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Rapid failures
for i := 0; i < 10; i++ {
cb.Execute(ctx, func() error {
return errors.New("rapid error")
})
}
assert.Equal(t, StateOpen, cb.GetState())
stats := cb.Stats()
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
}
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
t.Parallel()
timeout := 100 * time.Millisecond
config := &CircuitBreakerConfig{
MaxFailures: 1,
Timeout: timeout,
HalfOpenMaxRequests: 1,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
// Open circuit
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateOpen, cb.GetState())
// Wait just before timeout
time.Sleep(timeout - 20*time.Millisecond)
assert.False(t, cb.IsHealthy())
// Wait until after timeout
time.Sleep(40 * time.Millisecond)
// After timeout, AllowRequest should return true for transition to half-open
assert.True(t, cb.AllowRequest())
}
// TestCircuitBreaker_DefaultConfig tests default configuration
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
t.Parallel()
cb := NewCircuitBreaker(nil) // Should use defaults
assert.NotNil(t, cb)
assert.Equal(t, StateClosed, cb.GetState())
// Verify defaults by triggering circuit breaker behavior
ctx := context.Background()
// Test that it takes 5 failures to open (default MaxFailures)
for i := 0; i < 4; i++ {
cb.Execute(ctx, func() error {
return errors.New("error")
})
}
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
// 5th failure should open it
cb.Execute(ctx, func() error {
return errors.New("error")
})
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
}
// TestCircuitBreaker_StateString tests state string representation
func TestCircuitBreaker_StateString(t *testing.T) {
t.Parallel()
assert.Equal(t, "closed", StateClosed.String())
assert.Equal(t, "open", StateOpen.String())
assert.Equal(t, "half-open", StateHalfOpen.String())
assert.Equal(t, "unknown", State(999).String())
}
// Benchmark circuit breaker performance
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
config := &CircuitBreakerConfig{
MaxFailures: 100,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 10,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.Execute(ctx, func() error {
return nil
})
}
}
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
config := &CircuitBreakerConfig{
MaxFailures: 1000,
Timeout: 1 * time.Second,
HalfOpenMaxRequests: 10,
}
cb := NewCircuitBreaker(config)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cb.Execute(ctx, func() error {
if i%10 == 0 {
return errors.New("error")
}
return nil
})
}
}
+375
View File
@@ -0,0 +1,375 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"sync"
"sync/atomic"
"time"
)
// HealthStatus represents the health status of a backend
type HealthStatus int32
const (
// HealthUnknown indicates unknown health status
HealthUnknown HealthStatus = iota
// HealthHealthy indicates the backend is healthy
HealthHealthy
// HealthDegraded indicates the backend is degraded but operational
HealthDegraded
// HealthUnhealthy indicates the backend is unhealthy
HealthUnhealthy
)
// String returns the string representation of the health status
func (h HealthStatus) String() string {
switch h {
case HealthHealthy:
return "healthy"
case HealthDegraded:
return "degraded"
case HealthUnhealthy:
return "unhealthy"
default:
return "unknown"
}
}
// HealthCheckConfig holds configuration for the health checker
type HealthCheckConfig struct {
// CheckInterval is how often to check health
CheckInterval time.Duration
// Timeout is the timeout for each health check
Timeout time.Duration
// HealthyThreshold is the number of consecutive successes to become healthy
HealthyThreshold int
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
UnhealthyThreshold int
// DegradedThreshold is the latency threshold in ms to mark as degraded
DegradedThreshold time.Duration
// OnStatusChange is called when health status changes
OnStatusChange func(from, to HealthStatus)
// CheckFunc is the function to check health
CheckFunc func(ctx context.Context) error
}
// DefaultHealthCheckConfig returns default configuration
func DefaultHealthCheckConfig() *HealthCheckConfig {
return &HealthCheckConfig{
CheckInterval: 30 * time.Second,
Timeout: 5 * time.Second,
HealthyThreshold: 3,
UnhealthyThreshold: 3,
DegradedThreshold: 100 * time.Millisecond,
}
}
// HealthChecker monitors the health of a backend
type HealthChecker struct {
config *HealthCheckConfig
// Status tracking
status atomic.Int32
consecutiveSuccesses atomic.Int32
consecutiveFailures atomic.Int32
// Timing
lastCheckTime time.Time
lastSuccessTime time.Time
lastFailureTime time.Time
averageLatency atomic.Int64
timeMu sync.RWMutex
// Metrics
totalChecks atomic.Int64
totalSuccesses atomic.Int64
totalFailures atomic.Int64
statusChanges atomic.Int64
// Lifecycle
ticker *time.Ticker
stopChan chan struct{}
stopped atomic.Bool
wg sync.WaitGroup
}
// NewHealthChecker creates a new health checker
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
if config == nil {
config = DefaultHealthCheckConfig()
}
hc := &HealthChecker{
config: config,
stopChan: make(chan struct{}),
}
hc.status.Store(int32(HealthUnknown))
return hc
}
// Start begins health checking
func (hc *HealthChecker) Start() {
if hc.stopped.Load() {
return
}
hc.ticker = time.NewTicker(hc.config.CheckInterval)
hc.wg.Add(1)
go hc.checkLoop()
}
// Stop stops health checking
func (hc *HealthChecker) Stop() {
if hc.stopped.Swap(true) {
return // Already stopped
}
close(hc.stopChan)
if hc.ticker != nil {
hc.ticker.Stop()
}
hc.wg.Wait()
}
// checkLoop runs periodic health checks
func (hc *HealthChecker) checkLoop() {
defer hc.wg.Done()
// Initial check - log error but continue
if err := hc.Check(context.Background()); err != nil {
// Error is already tracked in Check() method, no need to log again
_ = err
}
for {
select {
case <-hc.stopChan:
return
case <-hc.ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
if err := hc.Check(ctx); err != nil {
// Error is already tracked in Check() method, no need to log again
_ = err
}
cancel()
}
}
}
// Check performs a health check
func (hc *HealthChecker) Check(ctx context.Context) error {
if hc.config.CheckFunc == nil {
return nil
}
hc.totalChecks.Add(1)
start := time.Now()
// Create timeout context if not already set
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
defer cancel()
}
// Perform health check
err := hc.config.CheckFunc(ctx)
latency := time.Since(start)
hc.timeMu.Lock()
hc.lastCheckTime = time.Now()
hc.timeMu.Unlock()
// Update average latency
hc.updateAverageLatency(latency)
if err != nil {
hc.recordFailure()
} else {
hc.recordSuccess(latency)
}
return err
}
// recordSuccess records a successful health check
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
hc.totalSuccesses.Add(1)
successes := hc.consecutiveSuccesses.Add(1)
hc.consecutiveFailures.Store(0)
hc.timeMu.Lock()
hc.lastSuccessTime = time.Now()
hc.timeMu.Unlock()
currentStatus := hc.GetStatus()
newStatus := currentStatus
// Check if we should become healthy
if successes >= int32(hc.config.HealthyThreshold) {
if latency > hc.config.DegradedThreshold {
newStatus = HealthDegraded
} else {
newStatus = HealthHealthy
}
}
if newStatus != currentStatus {
hc.setStatus(newStatus)
}
}
// recordFailure records a failed health check
func (hc *HealthChecker) recordFailure() {
hc.totalFailures.Add(1)
failures := hc.consecutiveFailures.Add(1)
hc.consecutiveSuccesses.Store(0)
hc.timeMu.Lock()
hc.lastFailureTime = time.Now()
hc.timeMu.Unlock()
// Check if we should become unhealthy
if failures >= int32(hc.config.UnhealthyThreshold) {
hc.setStatus(HealthUnhealthy)
}
}
// updateAverageLatency updates the rolling average latency
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
// Simple exponential moving average
currentAvg := time.Duration(hc.averageLatency.Load())
if currentAvg == 0 {
hc.averageLatency.Store(int64(latency))
} else {
// Weight: 0.2 for new value, 0.8 for old average
newAvg := (currentAvg*4 + latency) / 5
hc.averageLatency.Store(int64(newAvg))
}
}
// GetStatus returns the current health status
func (hc *HealthChecker) GetStatus() HealthStatus {
return HealthStatus(hc.status.Load())
}
// setStatus changes the health status
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
if oldStatus != newStatus {
hc.statusChanges.Add(1)
if hc.config.OnStatusChange != nil {
hc.config.OnStatusChange(oldStatus, newStatus)
}
}
}
// IsHealthy returns true if the backend is healthy or degraded
func (hc *HealthChecker) IsHealthy() bool {
status := hc.GetStatus()
return status == HealthHealthy || status == HealthDegraded
}
// LastCheckTime returns the time of the last health check
func (hc *HealthChecker) LastCheckTime() time.Time {
hc.timeMu.RLock()
defer hc.timeMu.RUnlock()
return hc.lastCheckTime
}
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
func (hc *HealthChecker) HealthScore() float64 {
status := hc.GetStatus()
switch status {
case HealthHealthy:
return 1.0
case HealthDegraded:
return 0.7
case HealthUnhealthy:
return 0.0
default:
return 0.5
}
}
// Stats returns health checker statistics
func (hc *HealthChecker) Stats() HealthCheckerStats {
hc.timeMu.RLock()
lastCheck := hc.lastCheckTime
lastSuccess := hc.lastSuccessTime
lastFailure := hc.lastFailureTime
hc.timeMu.RUnlock()
totalChecks := hc.totalChecks.Load()
totalSuccesses := hc.totalSuccesses.Load()
totalFailures := hc.totalFailures.Load()
successRate := float64(0)
if totalChecks > 0 {
successRate = float64(totalSuccesses) / float64(totalChecks)
}
return HealthCheckerStats{
Status: hc.GetStatus(),
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
ConsecutiveFailures: hc.consecutiveFailures.Load(),
TotalChecks: totalChecks,
TotalSuccesses: totalSuccesses,
TotalFailures: totalFailures,
SuccessRate: successRate,
AverageLatency: time.Duration(hc.averageLatency.Load()),
StatusChanges: hc.statusChanges.Load(),
LastCheckTime: lastCheck,
LastSuccessTime: lastSuccess,
LastFailureTime: lastFailure,
HealthScore: hc.HealthScore(),
}
}
// HealthCheckerStats holds statistics for the health checker
type HealthCheckerStats struct {
Status HealthStatus
ConsecutiveSuccesses int32
ConsecutiveFailures int32
TotalChecks int64
TotalSuccesses int64
TotalFailures int64
SuccessRate float64
AverageLatency time.Duration
StatusChanges int64
LastCheckTime time.Time
LastSuccessTime time.Time
LastFailureTime time.Time
HealthScore float64
}
// Reset resets the health checker statistics
func (hc *HealthChecker) Reset() {
hc.status.Store(int32(HealthUnknown))
hc.consecutiveSuccesses.Store(0)
hc.consecutiveFailures.Store(0)
hc.totalChecks.Store(0)
hc.totalSuccesses.Store(0)
hc.totalFailures.Store(0)
hc.statusChanges.Store(0)
hc.averageLatency.Store(0)
now := time.Now()
hc.timeMu.Lock()
hc.lastCheckTime = now
hc.lastSuccessTime = now
hc.lastFailureTime = now
hc.timeMu.Unlock()
}
+215
View File
@@ -0,0 +1,215 @@
// Package resilience provides resilience patterns for cache backends.
package resilience
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
)
// HealthCheckBackend wraps a cache backend with health checking
type HealthCheckBackend struct {
backend backends.CacheBackend
config *HealthCheckConfig
// Health tracking
status atomic.Int32
consecutiveFails atomic.Int32
consecutiveOK atomic.Int32
lastCheck time.Time
checkMutex sync.RWMutex
// Lifecycle
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewHealthCheckBackend creates a new health check wrapped backend
func NewHealthCheckBackend(b backends.CacheBackend, config *HealthCheckConfig) backends.CacheBackend {
if config == nil {
config = DefaultHealthCheckConfig()
}
ctx, cancel := context.WithCancel(context.Background())
hc := &HealthCheckBackend{
backend: b,
config: config,
ctx: ctx,
cancel: cancel,
}
// Set initial status to healthy (optimistic)
hc.status.Store(int32(HealthHealthy))
// Start health check routine
hc.wg.Add(1)
go hc.healthCheckLoop()
return hc
}
// Set stores a value and tracks health
func (h *HealthCheckBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
// Allow operations even if unhealthy (may recover)
err := h.backend.Set(ctx, key, value, ttl)
h.recordResult(err == nil)
return err
}
// Get retrieves a value and tracks health
func (h *HealthCheckBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
value, ttl, exists, err := h.backend.Get(ctx, key)
h.recordResult(err == nil)
return value, ttl, exists, err
}
// Delete removes a key and tracks health
func (h *HealthCheckBackend) Delete(ctx context.Context, key string) (bool, error) {
deleted, err := h.backend.Delete(ctx, key)
h.recordResult(err == nil)
return deleted, err
}
// Exists checks if a key exists and tracks health
func (h *HealthCheckBackend) Exists(ctx context.Context, key string) (bool, error) {
exists, err := h.backend.Exists(ctx, key)
h.recordResult(err == nil)
return exists, err
}
// Clear removes all keys and tracks health
func (h *HealthCheckBackend) Clear(ctx context.Context) error {
err := h.backend.Clear(ctx)
h.recordResult(err == nil)
return err
}
// GetStats returns statistics including health status
func (h *HealthCheckBackend) GetStats() map[string]interface{} {
stats := h.backend.GetStats()
if stats == nil {
stats = make(map[string]interface{})
}
h.checkMutex.RLock()
lastCheck := h.lastCheck
h.checkMutex.RUnlock()
status := HealthStatus(h.status.Load())
stats["health"] = map[string]interface{}{
"status": status.String(),
"consecutive_fails": h.consecutiveFails.Load(),
"consecutive_ok": h.consecutiveOK.Load(),
"last_check": lastCheck.Format(time.RFC3339),
"time_since_check": time.Since(lastCheck).Seconds(),
"check_interval_sec": h.config.CheckInterval.Seconds(),
}
return stats
}
// Ping checks backend health
func (h *HealthCheckBackend) Ping(ctx context.Context) error {
err := h.backend.Ping(ctx)
h.recordResult(err == nil)
return err
}
// Close shuts down the health checker and backend
func (h *HealthCheckBackend) Close() error {
// Stop health check routine
h.cancel()
// Wait for routine to finish
done := make(chan struct{})
go func() {
h.wg.Wait()
close(done)
}()
select {
case <-done:
// Finished normally
case <-time.After(2 * time.Second):
// Timeout
}
return h.backend.Close()
}
// IsHealthy returns true if the backend is healthy
func (h *HealthCheckBackend) IsHealthy() bool {
status := HealthStatus(h.status.Load())
return status == HealthHealthy || status == HealthDegraded
}
// recordResult records the result of an operation for health tracking
func (h *HealthCheckBackend) recordResult(success bool) {
if success {
fails := h.consecutiveFails.Swap(0)
oks := h.consecutiveOK.Add(1)
// Check if we should transition to healthy
if fails > 0 && oks >= int32(h.config.HealthyThreshold) {
oldStatus := HealthStatus(h.status.Swap(int32(HealthHealthy)))
if oldStatus != HealthHealthy && h.config.OnStatusChange != nil {
h.config.OnStatusChange(oldStatus, HealthHealthy)
}
}
} else {
oks := h.consecutiveOK.Swap(0)
fails := h.consecutiveFails.Add(1)
// Check if we should transition to unhealthy
if oks > 0 && fails >= int32(h.config.UnhealthyThreshold) {
oldStatus := HealthStatus(h.status.Swap(int32(HealthUnhealthy)))
if oldStatus != HealthUnhealthy && h.config.OnStatusChange != nil {
h.config.OnStatusChange(oldStatus, HealthUnhealthy)
}
} else if fails >= int32(h.config.UnhealthyThreshold)*2 {
// Severely degraded
h.status.Store(int32(HealthUnhealthy))
} else if fails >= int32(h.config.UnhealthyThreshold) {
// Degraded but still trying
h.status.Store(int32(HealthDegraded))
}
}
}
// healthCheckLoop runs periodic health checks
func (h *HealthCheckBackend) healthCheckLoop() {
defer h.wg.Done()
ticker := time.NewTicker(h.config.CheckInterval)
defer ticker.Stop()
// Do initial check
h.performHealthCheck()
for {
select {
case <-h.ctx.Done():
return
case <-ticker.C:
h.performHealthCheck()
}
}
}
// performHealthCheck performs a single health check
func (h *HealthCheckBackend) performHealthCheck() {
h.checkMutex.Lock()
h.lastCheck = time.Now()
h.checkMutex.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), h.config.Timeout)
defer cancel()
err := h.backend.Ping(ctx)
h.recordResult(err == nil)
}
+447
View File
@@ -0,0 +1,447 @@
package resilience
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestHealthChecker_StatusTransitions tests health status transitions
func TestHealthChecker_StatusTransitions(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
var shouldFail atomic.Bool
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
if shouldFail.Load() {
return errors.New("health check failed")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
// Initially unknown
assert.Equal(t, HealthUnknown, hc.GetStatus())
// Trigger failures
shouldFail.Store(true)
time.Sleep(200 * time.Millisecond)
// Should be unhealthy after threshold failures
status := hc.GetStatus()
assert.True(t, status == HealthUnhealthy || status == HealthDegraded)
// Recover
shouldFail.Store(false)
time.Sleep(150 * time.Millisecond)
// Should recover towards healthy
finalStatus := hc.GetStatus()
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded || finalStatus == HealthUnknown)
}
// TestHealthChecker_InitialState tests initial health status
func TestHealthChecker_InitialState(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
assert.Equal(t, HealthUnknown, hc.GetStatus())
assert.False(t, hc.IsHealthy())
}
// TestHealthChecker_ForceCheck tests manual health check trigger
func TestHealthChecker_ForceCheck(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Second, // Long interval
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
initialCount := callCount.Load()
// Force check
hc.Check(context.Background())
// Should have been called
assert.Greater(t, callCount.Load(), initialCount)
}
// TestHealthChecker_StatusChangeCallback tests status change notifications
func TestHealthChecker_StatusChangeCallback(t *testing.T) {
t.Parallel()
var transitions []string
var mu sync.Mutex
var shouldFail atomic.Bool
checkFunc := func(ctx context.Context) error {
if shouldFail.Load() {
return errors.New("health check failed")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 2,
HealthyThreshold: 2,
CheckFunc: checkFunc,
OnStatusChange: func(from, to HealthStatus) {
mu.Lock()
defer mu.Unlock()
transitions = append(transitions, from.String()+"->"+to.String())
},
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
// Trigger failures
shouldFail.Store(true)
time.Sleep(100 * time.Millisecond)
// Recover
shouldFail.Store(false)
time.Sleep(100 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
// Should have status transitions
assert.NotEmpty(t, transitions)
}
// TestHealthChecker_Stats tests statistics tracking
func TestHealthChecker_Stats(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
if callCount.Load()%2 == 0 {
return errors.New("failure")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 20 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 5,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(150 * time.Millisecond)
stats := hc.Stats()
assert.Greater(t, stats.TotalChecks, int64(0))
assert.Greater(t, stats.TotalFailures, int64(0))
assert.Greater(t, stats.SuccessRate, 0.0)
assert.Less(t, stats.SuccessRate, 1.0)
}
// TestHealthChecker_Timeout tests check timeout handling
func TestHealthChecker_Timeout(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
// Simulate slow check
select {
case <-time.After(100 * time.Millisecond):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
config := &HealthCheckConfig{
CheckInterval: 50 * time.Millisecond,
Timeout: 10 * time.Millisecond, // Short timeout
UnhealthyThreshold: 2,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(150 * time.Millisecond)
// Should be unhealthy due to timeouts
status := hc.GetStatus()
assert.NotEqual(t, HealthHealthy, status)
}
// TestHealthChecker_ConcurrentAccess tests thread safety
func TestHealthChecker_ConcurrentAccess(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Millisecond,
Timeout: 5 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
var wg sync.WaitGroup
goroutines := 20
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 50; j++ {
_ = hc.GetStatus()
_ = hc.IsHealthy()
_ = hc.Stats()
hc.Check(context.Background())
}
}()
}
wg.Wait()
// Should complete without panics
}
// TestHealthChecker_StopAndStart tests lifecycle management
func TestHealthChecker_StopAndStart(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
callCount.Add(1)
return nil
}
config := &HealthCheckConfig{
CheckInterval: 20 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
// Start
hc.Start()
time.Sleep(100 * time.Millisecond)
count1 := callCount.Load()
assert.Greater(t, count1, int32(0))
// Stop
hc.Stop()
time.Sleep(100 * time.Millisecond)
count2 := callCount.Load()
// Should not have increased significantly after stop
assert.Less(t, count2-count1, int32(3))
}
// TestHealthChecker_DegradedState tests degraded status
func TestHealthChecker_DegradedState(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
checkFunc := func(ctx context.Context) error {
count := callCount.Add(1)
// Fail once, then succeed
if count == 1 {
return errors.New("single failure")
}
return nil
}
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3, // Need 3 failures for unhealthy
HealthyThreshold: 2, // Need 2 successes for healthy
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(100 * time.Millisecond)
// After initial checks, status should be set (might be healthy or degraded based on execution)
status := hc.GetStatus()
assert.True(t, status != HealthUnknown, "Status should not be unknown after checks")
}
// TestHealthChecker_DefaultConfig tests default configuration
func TestHealthChecker_DefaultConfig(t *testing.T) {
t.Parallel()
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
assert.NotNil(t, hc)
assert.Equal(t, HealthUnknown, hc.GetStatus())
// Verify default config was applied (we can't access private fields, so just check it works)
assert.NotNil(t, hc)
}
// TestHealthChecker_StatusString tests status string representation
func TestHealthChecker_StatusString(t *testing.T) {
t.Parallel()
assert.Equal(t, "healthy", HealthHealthy.String())
assert.Equal(t, "unhealthy", HealthUnhealthy.String())
assert.Equal(t, "degraded", HealthDegraded.String())
assert.Equal(t, "unknown", HealthStatus(999).String())
}
// TestHealthChecker_RecoveryPattern tests typical failure and recovery
func TestHealthChecker_RecoveryPattern(t *testing.T) {
t.Parallel()
var checkNumber atomic.Int32
checkFunc := func(ctx context.Context) error {
n := checkNumber.Add(1)
// Fail checks 3-5, succeed others
if n >= 3 && n <= 5 {
return errors.New("temporary failure")
}
return nil
}
var statusLog []HealthStatus
var mu sync.Mutex
config := &HealthCheckConfig{
CheckInterval: 30 * time.Millisecond,
Timeout: 10 * time.Millisecond,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
OnStatusChange: func(from, to HealthStatus) {
mu.Lock()
defer mu.Unlock()
statusLog = append(statusLog, to)
},
}
hc := NewHealthChecker(config)
hc.Start()
defer hc.Stop()
time.Sleep(300 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
// Should see transitions through unhealthy and back to healthy
assert.NotEmpty(t, statusLog)
// Final status should be healthy or degraded (recovered)
finalStatus := hc.GetStatus()
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded, "Should have recovered")
}
// Benchmark health checker performance
func BenchmarkHealthChecker_ForceCheck(b *testing.B) {
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckInterval: 10 * time.Minute,
Timeout: 1 * time.Second,
UnhealthyThreshold: 3,
HealthyThreshold: 2,
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
b.ResetTimer()
for i := 0; i < b.N; i++ {
hc.Check(context.Background())
}
}
func BenchmarkHealthChecker_Status(b *testing.B) {
checkFunc := func(ctx context.Context) error {
return nil
}
config := &HealthCheckConfig{
CheckFunc: checkFunc,
}
hc := NewHealthChecker(config)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = hc.GetStatus()
}
}
+4 -4
View File
@@ -39,25 +39,25 @@ func (p *Auth0Provider) BuildAuthParams(baseParams url.Values, scopes []string)
// Ensure offline_access scope is present for refresh tokens
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
if scope == ScopeOfflineAccess {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
scopes = append(scopes, ScopeOfflineAccess)
}
// Ensure openid scope is present
hasOpenID := false
for _, scope := range scopes {
if scope == "openid" {
if scope == ScopeOpenID {
hasOpenID = true
break
}
}
if !hasOpenID {
scopes = append(scopes, "openid")
scopes = append(scopes, ScopeOpenID)
}
return &AuthParams{
+5 -5
View File
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
var filteredScopes []string
for _, scope := range scopes {
if strings.ToLower(scope) != "offline_access" {
if strings.ToLower(scope) != ScopeOfflineAccess {
filteredScopes = append(filteredScopes, scope)
}
}
@@ -48,18 +48,18 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Ensure openid scope is present
hasOpenID := false
for _, scope := range filteredScopes {
if scope == "openid" {
if scope == ScopeOpenID {
hasOpenID = true
break
}
}
if !hasOpenID {
filteredScopes = append(filteredScopes, "openid")
filteredScopes = append(filteredScopes, ScopeOpenID)
}
// Default Cognito scopes if none specified
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
filteredScopes = append(filteredScopes, "email", "profile")
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
filteredScopes = append(filteredScopes, ScopeEmail, ScopeProfile)
}
return &AuthParams{
+2 -2
View File
@@ -38,13 +38,13 @@ func (p *AzureProvider) BuildAuthParams(baseParams url.Values, scopes []string)
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
if scope == ScopeOfflineAccess {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
scopes = append(scopes, ScopeOfflineAccess)
}
return &AuthParams{
+3 -3
View File
@@ -102,17 +102,17 @@ func (p *BaseProvider) ValidateTokenExpiry(session Session, token string, tokenC
}
// BuildAuthParams constructs authorization parameters for the provider.
// It includes the "offline_access" scope by default for refresh token support.
// It includes the offline_access scope by default for refresh token support.
func (p *BaseProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
if scope == ScopeOfflineAccess {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
scopes = append(scopes, ScopeOfflineAccess)
}
return &AuthParams{
+1 -1
View File
@@ -38,7 +38,7 @@ func (p *GitHubProvider) BuildAuthParams(baseParams url.Values, scopes []string)
// GitHub doesn't use offline_access scope, so remove it if present
var filteredScopes []string
for _, scope := range scopes {
if scope != "offline_access" {
if scope != ScopeOfflineAccess {
filteredScopes = append(filteredScopes, scope)
}
}
+5 -5
View File
@@ -39,7 +39,7 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
// Remove offline_access scope as GitLab doesn't use it
var filteredScopes []string
for _, scope := range scopes {
if scope != "offline_access" {
if scope != ScopeOfflineAccess {
filteredScopes = append(filteredScopes, scope)
}
}
@@ -47,18 +47,18 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
// Ensure openid scope is present for OIDC
hasOpenID := false
for _, scope := range filteredScopes {
if scope == "openid" {
if scope == ScopeOpenID {
hasOpenID = true
break
}
}
if !hasOpenID {
filteredScopes = append(filteredScopes, "openid")
filteredScopes = append(filteredScopes, ScopeOpenID)
}
// Default GitLab scopes if none specified
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
filteredScopes = append(filteredScopes, "profile", "email")
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
filteredScopes = append(filteredScopes, ScopeProfile, ScopeEmail)
}
return &AuthParams{
+2 -2
View File
@@ -36,10 +36,10 @@ func (p *GoogleProvider) BuildAuthParams(baseParams url.Values, scopes []string)
baseParams.Set("access_type", "offline")
baseParams.Set("prompt", "consent")
// Google does not use the "offline_access" scope, so we remove it if present.
// Google does not use the ScopeOfflineAccess scope, so we remove it if present.
var filteredScopes []string
for _, scope := range scopes {
if scope != "offline_access" {
if scope != ScopeOfflineAccess {
filteredScopes = append(filteredScopes, scope)
}
}
+8
View File
@@ -33,6 +33,14 @@ const (
ProviderTypeGitLab
)
// Standard OAuth2/OIDC scope constants
const (
ScopeOfflineAccess = "offline_access"
ScopeOpenID = "openid"
ScopeProfile = "profile"
ScopeEmail = "email"
)
// ProviderCapabilities defines the specific features and behaviors of an OIDC provider.
type ProviderCapabilities struct {
PreferredTokenValidation string
+4 -4
View File
@@ -39,25 +39,25 @@ func (p *KeycloakProvider) BuildAuthParams(baseParams url.Values, scopes []strin
// Ensure offline_access scope is present for refresh tokens
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
if scope == ScopeOfflineAccess {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
scopes = append(scopes, ScopeOfflineAccess)
}
// Ensure openid scope is present
hasOpenID := false
for _, scope := range scopes {
if scope == "openid" {
if scope == ScopeOpenID {
hasOpenID = true
break
}
}
if !hasOpenID {
scopes = append(scopes, "openid")
scopes = append(scopes, ScopeOpenID)
}
return &AuthParams{
+4 -4
View File
@@ -39,25 +39,25 @@ func (p *OktaProvider) BuildAuthParams(baseParams url.Values, scopes []string) (
// Ensure offline_access scope is present for refresh tokens
hasOfflineAccess := false
for _, scope := range scopes {
if scope == "offline_access" {
if scope == ScopeOfflineAccess {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
scopes = append(scopes, "offline_access")
scopes = append(scopes, ScopeOfflineAccess)
}
// Ensure openid scope is present
hasOpenID := false
for _, scope := range scopes {
if scope == "openid" {
if scope == ScopeOpenID {
hasOpenID = true
break
}
}
if !hasOpenID {
scopes = append(scopes, "openid")
scopes = append(scopes, ScopeOpenID)
}
return &AuthParams{
+1 -1
View File
@@ -61,7 +61,7 @@ func (v *ConfigValidator) ValidateScopes(scopes []string) error {
hasOpenIDScope := false
for _, scope := range scopes {
if strings.TrimSpace(scope) == "openid" {
if strings.TrimSpace(scope) == ScopeOpenID {
hasOpenIDScope = true
break
}
+1 -1
View File
@@ -124,7 +124,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
httpClient = CreateDefaultHTTPClient()
}
goroutineWG := &sync.WaitGroup{}
cacheManager := GetGlobalCacheManager(goroutineWG)
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
// Use provided context instead of creating new one
var pluginCtx context.Context
+404
View File
@@ -0,0 +1,404 @@
package traefikoidc
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisIntegration_MultipleInstances tests cache sharing across multiple instances
func TestRedisIntegration_MultipleInstances(t *testing.T) {
t.Parallel()
// Start miniredis server
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
// Create two backend instances sharing the same Redis
config1 := backends.DefaultRedisConfig(mr.Addr())
config1.RedisPrefix = "shared:"
backend1, err := backends.NewRedisBackend(config1)
require.NoError(t, err)
defer backend1.Close()
config2 := backends.DefaultRedisConfig(mr.Addr())
config2.RedisPrefix = "shared:"
backend2, err := backends.NewRedisBackend(config2)
require.NoError(t, err)
defer backend2.Close()
t.Run("ShareTokenBlacklist", func(t *testing.T) {
// Instance 1 blacklists a JTI
jti := "test-jti-12345"
err := backend1.Set(ctx, "jti:"+jti, []byte("blacklisted"), 10*time.Minute)
require.NoError(t, err)
// Instance 2 should see the blacklisted JTI
_, _, exists, err := backend2.Get(ctx, "jti:"+jti)
require.NoError(t, err)
assert.True(t, exists, "JTI should be visible across instances")
})
t.Run("ShareTokenCache", func(t *testing.T) {
// Instance 1 caches a token
token := "access-token-xyz"
tokenData := []byte(`{"sub":"user123","exp":1234567890}`)
err := backend1.Set(ctx, "token:"+token, tokenData, 5*time.Minute)
require.NoError(t, err)
// Instance 2 retrieves the cached token
retrieved, _, exists, err := backend2.Get(ctx, "token:"+token)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, tokenData, retrieved)
})
t.Run("ShareMetadataCache", func(t *testing.T) {
// Instance 1 caches provider metadata
metadataKey := "metadata:provider123"
metadata := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
err := backend1.Set(ctx, metadataKey, metadata, 1*time.Hour)
require.NoError(t, err)
// Instance 2 retrieves the metadata
retrieved, ttl, exists, err := backend2.Get(ctx, metadataKey)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, metadata, retrieved)
assert.Greater(t, ttl, 50*time.Minute)
})
}
// TestRedisIntegration_JTIReplayDetection tests JTI replay detection across instances
func TestRedisIntegration_JTIReplayDetection(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
// Multiple Traefik instances
instances := make([]*backends.RedisBackend, 3)
for i := 0; i < 3; i++ {
config := backends.DefaultRedisConfig(mr.Addr())
config.RedisPrefix = "jti:"
instances[i], err = backends.NewRedisBackend(config)
require.NoError(t, err)
defer instances[i].Close()
}
t.Run("PreventReplayAcrossInstances", func(t *testing.T) {
jti := "replay-test-jti"
// First instance processes token and blacklists JTI
err := instances[0].Set(ctx, jti, []byte("used"), 24*time.Hour)
require.NoError(t, err)
// Other instances should detect the used JTI
for i := 1; i < 3; i++ {
exists, err := instances[i].Exists(ctx, jti)
require.NoError(t, err)
assert.True(t, exists, "Instance %d should see blacklisted JTI", i)
}
})
t.Run("ConcurrentJTIChecks", func(t *testing.T) {
jtiBase := "concurrent-jti"
var wg sync.WaitGroup
// Simulate concurrent token processing across instances
for instanceID := 0; instanceID < 3; instanceID++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 10; j++ {
jti := fmt.Sprintf("%s-%d-%d", jtiBase, id, j)
// Check if JTI exists
exists, _ := instances[id].Exists(ctx, jti)
if !exists {
// Mark as used
instances[id].Set(ctx, jti, []byte("used"), 1*time.Hour)
}
}
}(instanceID)
}
wg.Wait()
// Verify all JTIs were recorded
for instanceID := 0; instanceID < 3; instanceID++ {
for j := 0; j < 10; j++ {
jti := fmt.Sprintf("%s-%d-%d", jtiBase, instanceID, j)
exists, err := instances[0].Exists(ctx, jti)
require.NoError(t, err)
assert.True(t, exists, "JTI %s should exist", jti)
}
}
})
}
// TestRedisIntegration_Failover tests failover scenarios
func TestRedisIntegration_Failover(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
config := backends.DefaultRedisConfig(mr.Addr())
redisBackend, err := backends.NewRedisBackend(config)
require.NoError(t, err)
defer redisBackend.Close()
t.Run("RedisTemporaryFailure", func(t *testing.T) {
// Set some data
key := "failover-key"
value := []byte("failover-value")
err := redisBackend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
// Simulate Redis error
mr.SetError("simulated connection error")
// Operations should fail gracefully
_, _, exists, err := redisBackend.Get(ctx, key)
assert.Error(t, err)
assert.False(t, exists)
// Clear error
mr.SetError("")
// Operations should work again
retrieved, _, exists, err := redisBackend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Equal(t, value, retrieved)
})
}
// TestRedisIntegration_HighLoad tests high load scenarios
func TestRedisIntegration_HighLoad(t *testing.T) {
if testing.Short() {
t.Skip("Skipping high load test in short mode")
}
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
config := backends.DefaultRedisConfig(mr.Addr())
config.PoolSize = 20
redisBackend, err := backends.NewRedisBackend(config)
require.NoError(t, err)
defer redisBackend.Close()
t.Run("HighConcurrency", func(t *testing.T) {
var wg sync.WaitGroup
goroutines := 50
operations := 100
errors := make(chan error, goroutines*operations)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < operations; j++ {
key := fmt.Sprintf("high-load-key-%d-%d", id, j)
value := []byte(fmt.Sprintf("high-load-value-%d-%d", id, j))
// Write
if err := redisBackend.Set(ctx, key, value, 1*time.Minute); err != nil {
errors <- err
continue
}
// Read
retrieved, _, exists, err := redisBackend.Get(ctx, key)
if err != nil {
errors <- err
continue
}
if !exists {
errors <- fmt.Errorf("key %s does not exist", key)
continue
}
if string(retrieved) != string(value) {
errors <- fmt.Errorf("value mismatch for key %s", key)
}
}
}(i)
}
wg.Wait()
close(errors)
// Check for errors
errorCount := 0
for err := range errors {
t.Logf("Operation error: %v", err)
errorCount++
}
// Allow small error rate (< 1%)
totalOps := goroutines * operations
errorRate := float64(errorCount) / float64(totalOps)
assert.Less(t, errorRate, 0.01, "Error rate should be less than 1%%")
})
}
// TestRedisIntegration_TTLConsistency tests TTL consistency across operations
func TestRedisIntegration_TTLConsistency(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
config := backends.DefaultRedisConfig(mr.Addr())
redisBackend, err := backends.NewRedisBackend(config)
require.NoError(t, err)
defer redisBackend.Close()
t.Run("TTLAccuracy", func(t *testing.T) {
key := "ttl-test-key"
value := []byte("ttl-test-value")
ttl := 5 * time.Second
err := redisBackend.Set(ctx, key, value, ttl)
require.NoError(t, err)
// Check TTL immediately
_, ttl1, exists, err := redisBackend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Greater(t, ttl1, 4*time.Second)
assert.LessOrEqual(t, ttl1, ttl)
// Fast forward 2 seconds
mr.FastForward(2 * time.Second)
// Check TTL again
_, ttl2, exists, err := redisBackend.Get(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
assert.Less(t, ttl2, ttl1)
assert.Greater(t, ttl2, 2*time.Second)
// Fast forward past expiration
mr.FastForward(4 * time.Second)
// Should be expired
_, _, exists, err = redisBackend.Get(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
})
}
// TestRedisIntegration_MemoryUsage tests memory efficiency
func TestRedisIntegration_MemoryUsage(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory usage test in short mode")
}
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
config := backends.DefaultRedisConfig(mr.Addr())
redisBackend, err := backends.NewRedisBackend(config)
require.NoError(t, err)
defer redisBackend.Close()
t.Run("LargeDataset", func(t *testing.T) {
// Store 10,000 items
itemCount := 10000
for i := 0; i < itemCount; i++ {
key := fmt.Sprintf("memory-test-key-%d", i)
value := []byte(fmt.Sprintf("memory-test-value-%d-with-some-padding-to-make-it-larger", i))
err := redisBackend.Set(ctx, key, value, 10*time.Minute)
require.NoError(t, err)
// Log progress
if i%1000 == 0 {
t.Logf("Stored %d items", i)
}
}
// Verify all items exist
for i := 0; i < itemCount; i += 100 {
key := fmt.Sprintf("memory-test-key-%d", i)
exists, err := redisBackend.Exists(ctx, key)
require.NoError(t, err)
assert.True(t, exists)
}
// Check stats
stats := redisBackend.GetStats()
t.Logf("Redis backend stats: %+v", stats)
})
}
// TestRedisIntegration_Cleanup tests cache cleanup functionality
func TestRedisIntegration_Cleanup(t *testing.T) {
t.Parallel()
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
ctx := context.Background()
config := backends.DefaultRedisConfig(mr.Addr())
config.RedisPrefix = "cleanup-test:"
redisBackend, err := backends.NewRedisBackend(config)
require.NoError(t, err)
defer redisBackend.Close()
t.Run("BulkCleanup", func(t *testing.T) {
// Add many items
for i := 0; i < 100; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
err := redisBackend.Set(ctx, key, value, 1*time.Minute)
require.NoError(t, err)
}
// Clear all
err := redisBackend.Clear(ctx)
require.NoError(t, err)
// Verify all items are gone
for i := 0; i < 100; i++ {
key := fmt.Sprintf("cleanup-key-%d", i)
exists, err := redisBackend.Exists(ctx, key)
require.NoError(t, err)
assert.False(t, exists)
}
})
}
+17 -49
View File
@@ -971,7 +971,9 @@ func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig
// Returns:
// - The expiration time if present, nil if no 'exp' claim.
// - An error if JWT parsing fails.
func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
//
// extractJWTClaim extracts a time claim from a JWT token
func (cm *ChunkManager) extractJWTClaim(token, claimName string) (*time.Time, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
@@ -992,25 +994,29 @@ func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
exp, exists := claims["exp"]
claimValue, exists := claims[claimName]
if !exists {
return nil, nil
}
// Convert expiration to time.Time
var expTime time.Time
switch v := exp.(type) {
// Convert claim to time.Time
var claimTime time.Time
switch v := claimValue.(type) {
case float64:
expTime = time.Unix(int64(v), 0)
claimTime = time.Unix(int64(v), 0)
case int64:
expTime = time.Unix(v, 0)
claimTime = time.Unix(v, 0)
case int:
expTime = time.Unix(int64(v), 0)
claimTime = time.Unix(int64(v), 0)
default:
return nil, fmt.Errorf("invalid expiration format: %T", exp)
return nil, fmt.Errorf("invalid %s format: %T", claimName, claimValue)
}
return &expTime, nil
return &claimTime, nil
}
func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
return cm.extractJWTClaim(token, "exp")
}
// validateTokenFreshness checks if token is fresh enough for storage.
@@ -1062,45 +1068,7 @@ func (cm *ChunkManager) validateTokenFreshness(token string, config TokenConfig)
// - The issued at time if present, nil if no 'iat' claim.
// - An error if JWT parsing fails.
func (cm *ChunkManager) extractJWTIssuedAt(token string) (*time.Time, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
}
// Parse the JSON payload using pooled decoder
var claims map[string]interface{}
pm := pool.Get()
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
defer pm.PutJSONDecoder(decoder)
if err := decoder.Decode(&claims); err != nil {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
iat, exists := claims["iat"]
if !exists {
return nil, nil
}
// Convert issued at to time.Time
var iatTime time.Time
switch v := iat.(type) {
case float64:
iatTime = time.Unix(int64(v), 0)
case int64:
iatTime = time.Unix(v, 0)
case int:
iatTime = time.Unix(int64(v), 0)
default:
return nil, fmt.Errorf("invalid issued at format: %T", iat)
}
return &iatTime, nil
return cm.extractJWTClaim(token, "iat")
}
// CleanupExpiredSessions removes expired sessions to prevent memory leaks.
+413
View File
@@ -89,6 +89,73 @@ type Config struct {
// Recommended: true for multi-replica deployments
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
// Redis configures the Redis cache backend for distributed caching.
// When enabled, provides cache sharing across multiple Traefik replicas.
// Default: nil (disabled - uses in-memory caching)
Redis *RedisConfig `json:"redis,omitempty"`
}
// RedisConfig configures Redis cache backend settings for distributed caching.
// All fields support both JSON and YAML configuration for compatibility with Traefik's
// dynamic configuration (labels, YAML files, etc.)
type RedisConfig struct {
// Enabled indicates if Redis caching should be used (default: false)
Enabled bool `json:"enabled" yaml:"enabled"`
// Address is the Redis server address (e.g., "localhost:6379", "redis:6379")
Address string `json:"address" yaml:"address"`
// Password for Redis authentication (optional, leave empty for no auth)
Password string `json:"password,omitempty" yaml:"password,omitempty"`
// DB is the Redis database number to use (default: 0)
DB int `json:"db" yaml:"db"`
// KeyPrefix is the prefix for all Redis keys (default: "traefikoidc:")
KeyPrefix string `json:"keyPrefix" yaml:"keyPrefix"`
// PoolSize is the maximum number of socket connections (default: 10)
PoolSize int `json:"poolSize" yaml:"poolSize"`
// ConnectTimeout is the timeout for establishing connections in seconds (default: 5)
ConnectTimeout int `json:"connectTimeout" yaml:"connectTimeout"`
// ReadTimeout is the timeout for read operations in seconds (default: 3)
ReadTimeout int `json:"readTimeout" yaml:"readTimeout"`
// WriteTimeout is the timeout for write operations in seconds (default: 3)
WriteTimeout int `json:"writeTimeout" yaml:"writeTimeout"`
// EnableTLS indicates if TLS should be used for Redis connections (default: false)
EnableTLS bool `json:"enableTLS" yaml:"enableTLS"`
// TLSSkipVerify skips TLS certificate verification (not recommended for production)
TLSSkipVerify bool `json:"tlsSkipVerify" yaml:"tlsSkipVerify"`
// CacheMode determines the caching strategy: "redis" (Redis only), "hybrid" (Memory+Redis), "memory" (Memory only)
// Default: "redis" when enabled
CacheMode string `json:"cacheMode" yaml:"cacheMode"`
// HybridL1Size is the maximum number of items in L1 cache for hybrid mode (default: 500)
HybridL1Size int `json:"hybridL1Size" yaml:"hybridL1Size"`
// HybridL1MemoryMB is the maximum memory in MB for L1 cache in hybrid mode (default: 10)
HybridL1MemoryMB int64 `json:"hybridL1MemoryMB" yaml:"hybridL1MemoryMB"`
// EnableCircuitBreaker enables circuit breaker for Redis failures (default: true)
EnableCircuitBreaker bool `json:"enableCircuitBreaker" yaml:"enableCircuitBreaker"`
// CircuitBreakerThreshold is the number of failures before opening circuit (default: 5)
CircuitBreakerThreshold int `json:"circuitBreakerThreshold" yaml:"circuitBreakerThreshold"`
// CircuitBreakerTimeout is the timeout in seconds before attempting to close circuit (default: 60)
CircuitBreakerTimeout int `json:"circuitBreakerTimeout" yaml:"circuitBreakerTimeout"`
// EnableHealthCheck enables periodic health checks for Redis (default: true)
EnableHealthCheck bool `json:"enableHealthCheck" yaml:"enableHealthCheck"`
// HealthCheckInterval is the interval in seconds between health checks (default: 30)
HealthCheckInterval int `json:"healthCheckInterval" yaml:"healthCheckInterval"`
}
// SecurityHeadersConfig configures security headers for the plugin
@@ -167,11 +234,14 @@ const (
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
// - EnablePKCE: false (PKCE is opt-in)
// - Redis: nil (disabled by default, can be configured via Traefik config or env vars)
//
// CreateConfig initializes a new Config struct with default values for optional fields.
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
// Redis configuration can be provided through Traefik's dynamic configuration or
// as a fallback through environment variables.
//
// Returns:
// - A pointer to a new Config struct with default settings applied.
@@ -185,6 +255,7 @@ func CreateConfig() *Config {
OverrideScopes: false, // Default to appending scopes, not overriding
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
SecurityHeaders: createDefaultSecurityConfig(),
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
}
return c
@@ -329,6 +400,13 @@ func (c *Config) Validate() error {
}
}
// Validate Redis configuration if provided
if c.Redis != nil && c.Redis.Enabled {
if err := c.Redis.Validate(); err != nil {
return fmt.Errorf("redis configuration error: %w", err)
}
}
// Validate headers configuration for template security
for _, header := range c.Headers {
if header.Name == "" {
@@ -803,6 +881,341 @@ func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Req
}
// isOriginAllowed checks if an origin is in the allowed list
// Validate checks if the Redis configuration is valid
func (rc *RedisConfig) Validate() error {
if !rc.Enabled {
return nil
}
if rc.Address == "" {
return fmt.Errorf("redis address is required when Redis is enabled")
}
// Validate cache mode
if rc.CacheMode != "" {
validModes := map[string]bool{
"redis": true,
"hybrid": true,
"memory": true,
}
if !validModes[rc.CacheMode] {
return fmt.Errorf("invalid cache mode: %s (must be 'redis', 'hybrid', or 'memory')", rc.CacheMode)
}
}
// Validate connection settings
if rc.PoolSize < 0 {
return fmt.Errorf("pool size cannot be negative")
}
if rc.ConnectTimeout < 0 {
return fmt.Errorf("connect timeout cannot be negative")
}
if rc.ReadTimeout < 0 {
return fmt.Errorf("read timeout cannot be negative")
}
if rc.WriteTimeout < 0 {
return fmt.Errorf("write timeout cannot be negative")
}
// Validate hybrid mode settings
if rc.CacheMode == "hybrid" {
if rc.HybridL1Size < 0 {
return fmt.Errorf("hybrid L1 size cannot be negative")
}
if rc.HybridL1MemoryMB < 0 {
return fmt.Errorf("hybrid L1 memory cannot be negative")
}
}
// Validate circuit breaker settings
if rc.CircuitBreakerThreshold < 0 {
return fmt.Errorf("circuit breaker threshold cannot be negative")
}
if rc.CircuitBreakerTimeout < 0 {
return fmt.Errorf("circuit breaker timeout cannot be negative")
}
// Validate health check settings
if rc.HealthCheckInterval < 0 {
return fmt.Errorf("health check interval cannot be negative")
}
return nil
}
// ApplyDefaults sets default values for Redis configuration when fields are not explicitly set.
// This ensures reasonable defaults while allowing full customization through configuration.
func (rc *RedisConfig) ApplyDefaults() {
// Only apply defaults if Redis is enabled
if !rc.Enabled {
return
}
// Connection defaults
if rc.KeyPrefix == "" {
rc.KeyPrefix = "traefikoidc:"
}
if rc.PoolSize == 0 {
rc.PoolSize = 10
}
if rc.ConnectTimeout == 0 {
rc.ConnectTimeout = 5
}
if rc.ReadTimeout == 0 {
rc.ReadTimeout = 3
}
if rc.WriteTimeout == 0 {
rc.WriteTimeout = 3
}
// Cache mode defaults
if rc.CacheMode == "" {
rc.CacheMode = "redis" // Default to redis-only mode for simplicity
}
// Hybrid mode specific defaults
if rc.CacheMode == "hybrid" {
if rc.HybridL1Size == 0 {
rc.HybridL1Size = 500
}
if rc.HybridL1MemoryMB == 0 {
rc.HybridL1MemoryMB = 10
}
}
// Resilience features - these use a different pattern to detect if they were explicitly set
// Since bool fields default to false, we need to be careful about defaults
// For now, we'll enable by default only if not explicitly disabled via environment
if rc.CircuitBreakerThreshold == 0 {
rc.CircuitBreakerThreshold = 5
}
if rc.CircuitBreakerTimeout == 0 {
rc.CircuitBreakerTimeout = 60
}
if rc.HealthCheckInterval == 0 {
rc.HealthCheckInterval = 30
}
}
// ApplyEnvFallbacks applies environment variable values as fallbacks for empty config fields.
// This allows environment variables to be used as optional overrides only when the
// corresponding config field is not set through Traefik's dynamic configuration.
// The plugin configuration takes precedence over environment variables.
func (rc *RedisConfig) ApplyEnvFallbacks() {
// Only apply env fallbacks if Redis is not already configured
if !rc.Enabled {
// Check if Redis should be enabled from environment
enabledStr := os.Getenv("REDIS_ENABLED")
if enabledStr == "true" || enabledStr == "1" {
rc.Enabled = true
}
}
// Only apply other env vars if Redis is enabled
if !rc.Enabled {
return
}
// Apply environment variables only for empty fields
if rc.Address == "" {
if addr := os.Getenv("REDIS_ADDRESS"); addr != "" {
rc.Address = addr
}
}
if rc.Password == "" {
rc.Password = os.Getenv("REDIS_PASSWORD")
}
if rc.KeyPrefix == "" {
if prefix := os.Getenv("REDIS_KEY_PREFIX"); prefix != "" {
rc.KeyPrefix = prefix
}
}
if rc.CacheMode == "" {
if mode := os.Getenv("REDIS_CACHE_MODE"); mode != "" {
rc.CacheMode = mode
}
}
// Apply numeric values only if not already set
if rc.DB == 0 {
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
if db, err := strconv.Atoi(dbStr); err == nil && db > 0 {
rc.DB = db
}
}
}
if rc.PoolSize == 0 {
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil && poolSize > 0 {
rc.PoolSize = poolSize
}
}
}
if rc.ConnectTimeout == 0 {
if timeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); timeoutStr != "" {
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
rc.ConnectTimeout = timeout
}
}
}
if rc.ReadTimeout == 0 {
if timeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); timeoutStr != "" {
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
rc.ReadTimeout = timeout
}
}
}
if rc.WriteTimeout == 0 {
if timeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); timeoutStr != "" {
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
rc.WriteTimeout = timeout
}
}
}
// Apply boolean values from env only if not already set in config
if !rc.EnableTLS {
if tlsStr := os.Getenv("REDIS_ENABLE_TLS"); tlsStr == "true" || tlsStr == "1" {
rc.EnableTLS = true
}
}
if !rc.TLSSkipVerify {
if skipStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipStr == "true" || skipStr == "1" {
rc.TLSSkipVerify = true
}
}
// Hybrid mode settings
if rc.HybridL1Size == 0 {
if sizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); sizeStr != "" {
if size, err := strconv.Atoi(sizeStr); err == nil && size > 0 {
rc.HybridL1Size = size
}
}
}
if rc.HybridL1MemoryMB == 0 {
if memStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); memStr != "" {
if mem, err := strconv.ParseInt(memStr, 10, 64); err == nil && mem > 0 {
rc.HybridL1MemoryMB = mem
}
}
}
}
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
// This function is kept for backward compatibility but should not be used directly.
func LoadRedisConfigFromEnv() *RedisConfig {
// Check if Redis is enabled
enabledStr := os.Getenv("REDIS_ENABLED")
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
return nil
}
config := &RedisConfig{
Enabled: true,
}
// Parse numeric values
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
if db, err := strconv.Atoi(dbStr); err == nil {
config.DB = db
}
}
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
config.PoolSize = poolSize
}
}
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
config.ConnectTimeout = timeout
}
}
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
config.ReadTimeout = timeout
}
}
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
config.WriteTimeout = timeout
}
}
// Parse boolean values
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
config.EnableTLS = true
}
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
config.TLSSkipVerify = true
}
// Parse hybrid mode settings
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
if size, err := strconv.Atoi(l1SizeStr); err == nil {
config.HybridL1Size = size
}
}
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
config.HybridL1MemoryMB = memory
}
}
// Parse circuit breaker settings
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
config.EnableCircuitBreaker = false
} else {
config.EnableCircuitBreaker = true // Default to enabled
}
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
config.CircuitBreakerThreshold = threshold
}
}
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
config.CircuitBreakerTimeout = timeout
}
}
// Parse health check settings
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
config.EnableHealthCheck = false
} else {
config.EnableHealthCheck = true // Default to enabled
}
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
config.HealthCheckInterval = interval
}
}
// Apply defaults after loading from env
config.ApplyDefaults()
return config
}
func isOriginAllowed(origin string, allowedOrigins []string) bool {
for _, allowed := range allowedOrigins {
if origin == allowed || allowed == "*" {
+144
View File
@@ -3,10 +3,13 @@ package traefikoidc
import (
"container/list"
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
)
// CacheType defines the type of cache for optimized behavior
@@ -89,6 +92,9 @@ type UniversalCache struct {
config UniversalCacheConfig
logger *Logger
// Backend for distributed caching (NEW)
backend backends.CacheBackend
// Memory management
currentSize int64
currentMemory int64
@@ -110,6 +116,13 @@ func NewUniversalCache(config UniversalCacheConfig) *UniversalCache {
return createUniversalCache(config)
}
// NewUniversalCacheWithBackend creates a new universal cache with a specific backend
func NewUniversalCacheWithBackend(config UniversalCacheConfig, cacheBackend backends.CacheBackend) *UniversalCache {
cache := createUniversalCache(config)
cache.backend = cacheBackend
return cache
}
// createUniversalCache is the internal constructor
func createUniversalCache(config UniversalCacheConfig) *UniversalCache {
// Apply type-specific defaults first (including MaxSize)
@@ -223,6 +236,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
ttl = c.config.DefaultTTL
}
// If we have a backend, use it for distributed caching
if c.backend != nil {
// Serialize the value
data, err := c.serialize(value)
if err != nil {
c.logger.Errorf("Failed to serialize value for key %s: %v", key, err)
return err
}
// Store in backend
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
if err := c.backend.Set(ctx, c.prefixKey(key), data, ttl); err != nil {
c.logger.Infof("Backend set error for key %s: %v", key, err)
// Continue with local cache even if backend fails
}
}
size := c.estimateSize(value)
c.mu.Lock()
@@ -285,6 +317,32 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
// Get retrieves a value from the cache
func (c *UniversalCache) Get(key string) (interface{}, bool) {
// Try backend first if available (for distributed consistency)
if c.backend != nil {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
data, _, exists, err := c.backend.Get(ctx, c.prefixKey(key))
if err != nil {
c.logger.Debugf("Backend get error for key %s: %v", key, err)
// Fall through to local cache
} else if exists {
// Deserialize the value
var value interface{}
if err := c.deserialize(data, &value); err != nil {
c.logger.Errorf("Failed to deserialize value for key %s: %v", key, err)
// Fall through to local cache
} else {
atomic.AddInt64(&c.hits, 1)
// Update local cache with backend value
go func() {
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
}()
return value, true
}
}
}
c.mu.Lock()
defer c.mu.Unlock()
@@ -341,6 +399,17 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
// Delete removes a key from the cache
func (c *UniversalCache) Delete(key string) bool {
// Delete from backend if available
if c.backend != nil {
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
if _, err := c.backend.Delete(ctx, c.prefixKey(key)); err != nil {
c.logger.Debugf("Backend delete error for key %s: %v", key, err)
// Continue with local delete
}
}
c.mu.Lock()
defer c.mu.Unlock()
@@ -355,6 +424,17 @@ func (c *UniversalCache) Delete(key string) bool {
// Clear removes all items from the cache
func (c *UniversalCache) Clear() {
// Clear backend if available
if c.backend != nil {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
if err := c.backend.Clear(ctx); err != nil {
c.logger.Infof("Backend clear error: %v", err)
// Continue with local clear
}
}
c.mu.Lock()
defer c.mu.Unlock()
@@ -437,6 +517,13 @@ func (c *UniversalCache) Close() error {
// Clear all items
c.Clear()
// Close backend if present
if c.backend != nil {
if err := c.backend.Close(); err != nil {
c.logger.Infof("Failed to close cache backend: %v", err)
}
}
c.logger.Debugf("UniversalCache[%s]: Closed", c.config.Type)
return nil
}
@@ -701,3 +788,60 @@ func (c *UniversalCache) Mutex() *sync.RWMutex {
func (c *UniversalCache) Strategy() CacheStrategy {
return c.config.Strategy
}
// serialize converts a value to bytes for backend storage
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
// Use JSON for serialization - simple and universal
return json.Marshal(value)
}
// deserialize converts bytes from backend storage to a value
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
// Use JSON for deserialization
return json.Unmarshal(data, value)
}
// prefixKey adds a cache type prefix to the key for backend storage
func (c *UniversalCache) prefixKey(key string) string {
return fmt.Sprintf("%s:%s", c.config.Type, key)
}
// updateLocalCache updates the local cache with a value from the backend
func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl time.Duration) error {
size := c.estimateSize(value)
c.mu.Lock()
defer c.mu.Unlock()
// Check memory limits
if c.config.MaxMemoryBytes > 0 {
for c.currentMemory+size > c.config.MaxMemoryBytes && c.lruList.Len() > 0 {
c.evictOldest()
}
}
// Check size limits
if c.lruList.Len() >= c.config.MaxSize {
c.evictOldest()
}
now := time.Now()
item := &CacheItem{
Key: key,
Value: value,
Size: size,
ExpiresAt: now.Add(ttl),
LastAccessed: now,
AccessCount: 1,
CacheType: c.config.Type,
Metadata: make(map[string]interface{}),
}
item.element = c.lruList.PushFront(key)
c.items[key] = item
c.currentSize++
c.currentMemory += size
return nil
}
+270 -39
View File
@@ -3,6 +3,9 @@ package traefikoidc
import (
"sync"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/lukaszraczylo/traefikoidc/internal/cache/resilience"
)
// UniversalCacheManager manages all cache instances using the universal cache
@@ -34,25 +37,217 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager {
logger: logger,
}
// Initialize token cache - CRITICAL FIX: Reduced from 5000 to 1000
universalCacheManager.tokenCache = NewUniversalCache(UniversalCacheConfig{
// Initialize with default in-memory backends
initializeDefaultCaches(universalCacheManager, logger)
})
return universalCacheManager
}
// GetUniversalCacheManagerWithConfig returns the singleton universal cache manager with Redis configuration
func GetUniversalCacheManagerWithConfig(logger *Logger, redisConfig *RedisConfig) *UniversalCacheManager {
universalCacheManagerOnce.Do(func() {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
universalCacheManager = &UniversalCacheManager{
logger: logger,
}
if redisConfig != nil && redisConfig.Enabled {
logger.Infof("Initializing cache manager with Redis backend: %s", redisConfig.Address)
initializeCachesWithRedis(universalCacheManager, logger, redisConfig)
} else {
logger.Info("Initializing cache manager with memory-only backend")
initializeDefaultCaches(universalCacheManager, logger)
}
})
return universalCacheManager
}
// initializeDefaultCaches initializes caches with memory-only backends
func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
// Initialize token cache - CRITICAL FIX: Reduced from 5000 to 1000
manager.tokenCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
DefaultTTL: 1 * time.Hour,
Logger: logger,
})
// Initialize blacklist cache
manager.blacklistCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 24 * time.Hour,
Logger: logger,
})
// Initialize metadata cache with grace periods
manager.metadataCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
DefaultTTL: 1 * time.Hour,
MetadataConfig: &MetadataCacheConfig{
GracePeriod: 5 * time.Minute,
ExtendedGracePeriod: 15 * time.Minute,
MaxGracePeriod: 30 * time.Minute,
SecurityCriticalMaxGracePeriod: 15 * time.Minute,
SecurityCriticalFields: []string{
"jwks_uri",
"token_endpoint",
"authorization_endpoint",
"issuer",
},
},
Logger: logger,
})
// Initialize JWK cache
manager.jwkCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 200,
DefaultTTL: 1 * time.Hour,
Logger: logger,
})
// Initialize session cache - CRITICAL FIX: Reduced from 10000 to 2000
manager.sessionCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
DefaultTTL: 30 * time.Minute,
Logger: logger,
})
// Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662)
manager.introspectionCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken, // Use token cache type for introspection results
MaxSize: 1000, // Cache up to 1000 introspection results
DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently)
Logger: logger,
})
// Initialize token type cache for performance optimization
manager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken, // Use token cache type for token type detection
MaxSize: 2000, // Cache up to 2000 token type detections
DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection
Logger: logger,
})
}
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, redisConfig *RedisConfig) {
// Apply defaults to Redis config
redisConfig.ApplyDefaults()
// Create Redis backend
redisBackendConfig := &backends.Config{
Type: backends.BackendTypeRedis,
RedisAddr: redisConfig.Address,
RedisPassword: redisConfig.Password,
RedisDB: redisConfig.DB,
RedisPrefix: redisConfig.KeyPrefix,
PoolSize: redisConfig.PoolSize,
EnableMetrics: true,
}
var redisBackend backends.CacheBackend
var err error
// Create Redis backend with resilience features if enabled
redisBackend, err = backends.NewRedisBackend(redisBackendConfig)
if err != nil {
logger.Errorf("Failed to create Redis backend: %v. Falling back to memory-only mode.", err)
initializeDefaultCaches(manager, logger)
return
}
// Wrap with circuit breaker if enabled
if redisConfig.EnableCircuitBreaker {
cbConfig := resilience.DefaultCircuitBreakerConfig()
cbConfig.MaxFailures = redisConfig.CircuitBreakerThreshold
cbConfig.Timeout = time.Duration(redisConfig.CircuitBreakerTimeout) * time.Second
cbConfig.OnStateChange = func(from, to resilience.State) {
logger.Infof("Circuit breaker state changed from %s to %s", from, to)
}
redisBackend = resilience.NewCircuitBreakerBackend(redisBackend, cbConfig)
logger.Info("Redis backend wrapped with circuit breaker")
}
// Wrap with health checker if enabled
if redisConfig.EnableHealthCheck {
hcConfig := &resilience.HealthCheckConfig{
CheckInterval: time.Duration(redisConfig.HealthCheckInterval) * time.Second,
Timeout: 5 * time.Second,
HealthyThreshold: 2,
UnhealthyThreshold: 3,
OnStatusChange: func(from, to resilience.HealthStatus) {
logger.Infof("Redis backend health status changed from %s to %s", from, to)
},
}
redisBackend = resilience.NewHealthCheckBackend(redisBackend, hcConfig)
logger.Info("Redis backend wrapped with health checker")
}
// Decide which backend to use based on cache mode
var createBackend func(cacheType CacheType) backends.CacheBackend
switch redisConfig.CacheMode {
case "redis":
// Redis-only mode
createBackend = func(cacheType CacheType) backends.CacheBackend {
return redisBackend
}
logger.Info("Using Redis-only cache backend")
case "hybrid":
// Hybrid mode is not currently supported due to interface incompatibilities
// Fall back to Redis-only mode
logger.Info("Hybrid mode not currently supported, using Redis-only mode")
createBackend = func(cacheType CacheType) backends.CacheBackend {
return redisBackend
}
default:
// Memory-only mode (fallback)
logger.Infof("Invalid cache mode: %s. Using memory-only mode.", redisConfig.CacheMode)
initializeDefaultCaches(manager, logger)
return
}
// Initialize token cache with backend
manager.tokenCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
MaxSize: 1000,
MaxMemoryBytes: 5 * 1024 * 1024,
DefaultTTL: 1 * time.Hour,
Logger: logger,
})
},
createBackend(CacheTypeToken),
)
// Initialize blacklist cache
universalCacheManager.blacklistCache = NewUniversalCache(UniversalCacheConfig{
// Initialize blacklist cache (CRITICAL - must be consistent across replicas)
manager.blacklistCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 24 * time.Hour,
Logger: logger,
})
},
createBackend("blacklist"),
)
// Initialize metadata cache with grace periods
universalCacheManager.metadataCache = NewUniversalCache(UniversalCacheConfig{
// Initialize metadata cache
manager.metadataCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeMetadata,
MaxSize: 100,
DefaultTTL: 1 * time.Hour,
@@ -69,43 +264,50 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager {
},
},
Logger: logger,
})
},
createBackend(CacheTypeMetadata),
)
// Initialize JWK cache
universalCacheManager.jwkCache = NewUniversalCache(UniversalCacheConfig{
// Initialize JWK cache
manager.jwkCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 200,
DefaultTTL: 1 * time.Hour,
Logger: logger,
})
},
createBackend(CacheTypeJWK),
)
// Initialize session cache - CRITICAL FIX: Reduced from 10000 to 2000
universalCacheManager.sessionCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
DefaultTTL: 30 * time.Minute,
Logger: logger,
})
// Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662)
universalCacheManager.introspectionCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken, // Use token cache type for introspection results
MaxSize: 1000, // Cache up to 1000 introspection results
DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently)
Logger: logger,
})
// Initialize token type cache for performance optimization
universalCacheManager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken, // Use token cache type for token type detection
MaxSize: 2000, // Cache up to 2000 token type detections
DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection
Logger: logger,
})
// Session cache stays memory-only (high volume, local state)
manager.sessionCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeSession,
MaxSize: 2000,
MaxMemoryBytes: 5 * 1024 * 1024,
DefaultTTL: 30 * time.Minute,
Logger: logger,
})
return universalCacheManager
// Introspection cache uses backend for sharing results
manager.introspectionCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 5 * time.Minute,
Logger: logger,
},
createBackend(CacheTypeToken),
)
// Token type cache stays memory-only (local optimization)
manager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 2000,
DefaultTTL: 5 * time.Minute,
Logger: logger,
})
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
}
// GetTokenCache returns the token cache
@@ -174,6 +376,35 @@ func (m *UniversalCacheManager) Close() error {
return nil
}
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
// This should be called early in the application startup with the loaded configuration
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
logger := NewLogger(config.LogLevel)
// Initialize Redis config if not present
if config.Redis == nil {
config.Redis = &RedisConfig{}
}
// Apply environment variable fallbacks for fields not set in config
// This allows env vars to be used as optional overrides only when
// the config field is not explicitly set through Traefik
config.Redis.ApplyEnvFallbacks()
// Apply defaults after env fallbacks
config.Redis.ApplyDefaults()
// Log cache backend selection
if config.Redis != nil && config.Redis.Enabled {
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
config.Redis.CacheMode, config.Redis.Address)
} else {
logger.Info("Initializing cache backend with memory-only mode")
}
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
}
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
// This should only be called in test code to ensure proper cleanup between tests
func ResetUniversalCacheManagerForTesting() {
+6
View File
@@ -0,0 +1,6 @@
/integration/redis_src/
/integration/dump.rdb
*.swp
/integration/nodes.conf
.idea/
miniredis.iml
+328
View File
@@ -0,0 +1,328 @@
## Changelog
## v2.35.0
- add Lua redis.setresp({2,3})
- embed gopher-json package
- fix XAUTOCLAIM (thanks @kgunning)
- fix writeXpending (thanks @gnpaone)
- fix BLMOVE TTL special case
- constants for key types @alyssaruth
### v2.34.0
- fix ZINTERSTORE where target is one of the source sets
- added support for ZRank and ZRevRank with score (thanks Jeff Howell)
- fix MEMORY subcommand casing (thanks @joshaber)
- use streamCmp in Xtrim (thanks @daniel-cohere)
### v2.33.0
- minimum Go version is now 1.17
- fix integer overflow (thanks @wszaranski)
- test against the last BSD redis (7.2.4)
- ignore 'redis.set_repl()' call (thanks @TingluoHuang)
- various build fixes (thanks @wszaranski)
- add StartAddrTLS function (thanks @agriffaut)
- support for the NOMKSTREAM option for XADD (thanks @Jahaja)
- return empty array for SRANDMEMBER on nonexistent key (thanks @WKBae)
### v2.32.1
- support for SINTERCARD (thanks @s-barr-fetch)
- support for EXPIRETIME and PEXPIRETIME (thanks @wszaranski)
- fix GEO* units to be case insensitive
### v2.31.1
- support COUNT in SCAN and ZSCAN (thanks @BarakSilverfort)
- support for OBJECT IDLETIME (thanks @nerd2)
- support for HRANDFIELD (thanks @sejin-P)
### v2.31.0
- support for MEMORY USAGE (thanks @davidroman0O)
- test against Redis 7.2.0
- support for CLIENT SETNAME/GETNAME (thanks @mr-karan)
- fix very small numbers (thanks @zsh1995)
- use the same float-to-string logic real Redis uses
### v2.30.5
- support SMISMEMBER (thanks @sandyharvie)
### v2.30.4
- fix ZADD LT/LG (thanks @sejin-P)
- fix COPY (thanks @jerargus)
- quicker SPOP
### v2.30.3
- fix lua error_reply (thanks @pkierski)
- fix use of blocking functions in lua
- support for ZMSCORE (thanks @lsgndln)
- lua cache (thanks @tonyhb)
### v2.30.2
- support MINID in XADD (thanks @nathan-cormier)
- support BLMOVE (thanks @sevein)
- fix COMMAND (thanks @pje)
- fix 'XREAD ... $' on a non-existing stream
### v2.30.1
- support SET NX GET special case
### v2.30.0
- implement redis 7.0.x (from 6.X). Main changes:
- test against 7.0.7
- update error messages
- support nx|xx|gt|lt options in [P]EXPIRE[AT]
- update how deleted items are processed in pending queues in streams
### v2.23.1
- resolve $ to latest ID in XREAD (thanks @josh-hook)
- handle disconnect in blocking functions (thanks @jgirtakovskis)
- fix type conversion bug in redisToLua (thanks Sandy Harvie)
- BRPOP{LPUSH} timeout can be float since 6.0
### v2.23.0
- basic INFO support (thanks @kirill-a-belov)
- support COUNT in SSCAN (thanks @Abdi-dd)
- test and support Go 1.19
- support LPOS (thanks @ianstarz)
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
### v2.22.0
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
- fix invalid resposne of COMMAND (thanks @zsh1995)
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
### v2.21.0
- support for GETEX (thanks @dntj)
- support for GT and LT in ZADD (thanks @lsgndln)
- support for XAUTOCLAIM (thanks @randall-fulton)
### v2.20.0
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
### v2.19.0
- support for TYPE in SCAN (thanks @0xDiddi)
- update BITPOS (thanks @dirkm)
- fix a lua redis.call() return value (thanks @mpetronic)
- update ZRANGE (thanks @valdemarpereira)
### v2.18.0
- support for ZUNION (thanks @propan)
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
- support for LMOVE (thanks @btwear)
### v2.17.0
- added miniredis.RunT(t)
### v2.16.1
- fix ZINTERSTORE with sets (thanks @lingjl2010 and @okhowang)
- fix exclusive ranges in XRANGE (thanks @joseotoro)
### v2.16.0
- simplify some code (thanks @zonque)
- support for EXAT/PXAT in SET
- support for XTRIM (thanks @joseotoro)
- support for ZRANDMEMBER
- support for redis.log() in lua (thanks @dirkm)
### v2.15.2
- Fix race condition in blocking code (thanks @zonque and @robx)
- XREAD accepts '$' as ID (thanks @bradengroom)
### v2.15.1
- EVAL should cache the script (thanks @guoshimin)
### v2.15.0
- target redis 6.2 and added new args to various commands
- support for all hyperlog commands (thanks @ilbaktin)
- support for GETDEL (thanks @wszaranski)
### v2.14.5
- added XPENDING
- support for BLOCK option in XREAD and XREADGROUP
### v2.14.4
- fix BITPOS error (thanks @xiaoyuzdy)
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
- fix empty EXEC return type (thanks @ashanbrown)
- fix XDEL (thanks @svakili and @yvesf)
- fix FLUSHALL for streams (thanks @svakili)
### v2.14.3
- fix problem where Lua code didn't set the selected DB
- update to redis 6.0.10 (thanks @lazappa)
### v2.14.2
- update LUA dependency
- deal with (p)unsubscribe when there are no channels
### v2.14.1
- mod tidy
### v2.14.0
- support for HELLO and the RESP3 protocol
- KEEPTTL in SET (thanks @johnpena)
### v2.13.3
- support Go 1.14 and 1.15
- update the `Check...()` methods
- support for XREAD (thanks @pieterlexis)
### v2.13.2
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
- changed unit and integration tests to compare raw payloads, not parsed payloads
- remove "redigo" dependency
### v2.13.1
- added HSTRLEN
- minimal support for ACL users in AUTH
### v2.13.0
- added RunTLS(...)
- added SetError(...)
### v2.12.0
- redis 6
- Lua json update (thanks @gsmith85)
- CLUSTER commands (thanks @kratisto)
- fix TOUCH
- fix a shutdown race condition
### v2.11.4
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
### v2.11.3
- support for TOUCH (thanks @cleroux)
- support for cluster and stream commands (thanks @kak-tus)
### v2.11.2
- make sure Lua code is executed concurrently
- add command GEORADIUSBYMEMBER (thanks @kyeett)
### v2.11.1
- globals protection for Lua code (thanks @vk-outreach)
- HSET update (thanks @carlgreen)
- fix BLPOP block on shutdown (thanks @Asalle)
### v2.11.0
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
- added GEODIST
- improved precision for geohashes, closer to what real redis does
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
### v2.10.1
- added m.Server()
### v2.10.0
- added UNLINK
- fix DEL zero-argument case
- cleanup some direct access commands
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
### v2.9.1
- fix issue with ZRANGEBYLEX
- fix issue with BRPOPLPUSH and direct access
### v2.9.0
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
- fix messages generated by PSUBSCRIBE
- optional internal seed (thanks @zikaeroh)
### v2.8.0
Proper `v2` in go.mod.
### older
See https://github.com/alicebob/miniredis/releases for the full changelog
+21
View File
@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Harmen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+33
View File
@@ -0,0 +1,33 @@
.PHONY: test
test: ### Run unit tests
go test ./...
.PHONY: testrace
testrace: ### Run unit tests with race detector
go test -race ./...
.PHONY: int
int: ### Run integration tests (doesn't download redis server)
${MAKE} -C integration int
.PHONY: ci
ci: ### Run full tests suite (including download and compilation of proper redis server)
${MAKE} test
${MAKE} -C integration redis_src/redis-server int
${MAKE} testrace
.PHONY: clean
clean: ### Clean integration test files and remove compiled redis from integration/redis_src
${MAKE} -C integration clean
.PHONY: help
help:
ifeq ($(UNAME), Linux)
@grep -P '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
else
@# this is not tested, but prepared in advance for you, Mac drivers
@awk -F ':.*###' '$$0 ~ FS {printf "%15s%s\n", $$1 ":", $$2}' \
$(MAKEFILE_LIST) | grep -v '@awk' | sort
endif
+342
View File
@@ -0,0 +1,342 @@
# Miniredis
Pure Go Redis test server, used in Go unittests.
##
Sometimes you want to test code which uses Redis, without making it a full-blown
integration test.
Miniredis implements (parts of) the Redis server, to be used in unittests. It
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
It saves you from using mock code, and since the redis server lives in the
test process you can query for values directly, without going through the server
stack.
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
Be sure to import v2:
```
import "github.com/alicebob/miniredis/v2"
```
## Commands
Implemented commands:
- Connection (complete)
- AUTH -- see RequireAuth()
- ECHO
- HELLO -- see RequireUserAuth()
- PING
- SELECT
- SWAPDB
- QUIT
- Key
- COPY
- DEL
- EXISTS
- EXPIRE
- EXPIREAT
- EXPIRETIME
- KEYS
- MOVE
- PERSIST
- PEXPIRE
- PEXPIREAT
- PEXPIRETIME
- PTTL
- RANDOMKEY -- see m.Seed(...)
- RENAME
- RENAMENX
- SCAN
- TOUCH
- TTL
- TYPE
- UNLINK
- Transactions (complete)
- DISCARD
- EXEC
- MULTI
- UNWATCH
- WATCH
- Server
- DBSIZE
- FLUSHALL
- FLUSHDB
- TIME -- returns time.Now() or value set by SetTime()
- COMMAND -- partly
- INFO -- partly, returns only "clients" section with one field "connected_clients"
- String keys (complete)
- APPEND
- BITCOUNT
- BITOP
- BITPOS
- DECR
- DECRBY
- GET
- GETBIT
- GETRANGE
- GETSET
- GETDEL
- GETEX
- INCR
- INCRBY
- INCRBYFLOAT
- MGET
- MSET
- MSETNX
- PSETEX
- SET
- SETBIT
- SETEX
- SETNX
- SETRANGE
- STRLEN
- Hash keys (complete)
- HDEL
- HEXISTS
- HGET
- HGETALL
- HINCRBY
- HINCRBYFLOAT
- HKEYS
- HLEN
- HMGET
- HMSET
- HRANDFIELD
- HSET
- HSETNX
- HSTRLEN
- HVALS
- HSCAN
- List keys (complete)
- BLPOP
- BRPOP
- BRPOPLPUSH
- LINDEX
- LINSERT
- LLEN
- LPOP
- LPUSH
- LPUSHX
- LRANGE
- LREM
- LSET
- LTRIM
- RPOP
- RPOPLPUSH
- RPUSH
- RPUSHX
- LMOVE
- BLMOVE
- Pub/Sub (complete)
- PSUBSCRIBE
- PUBLISH
- PUBSUB
- PUNSUBSCRIBE
- SUBSCRIBE
- UNSUBSCRIBE
- Set keys (complete)
- SADD
- SCARD
- SDIFF
- SDIFFSTORE
- SINTER
- SINTERSTORE
- SINTERCARD
- SISMEMBER
- SMEMBERS
- SMISMEMBER
- SMOVE
- SPOP -- see m.Seed(...)
- SRANDMEMBER -- see m.Seed(...)
- SREM
- SSCAN
- SUNION
- SUNIONSTORE
- Sorted Set keys (complete)
- ZADD
- ZCARD
- ZCOUNT
- ZINCRBY
- ZINTER
- ZINTERSTORE
- ZLEXCOUNT
- ZPOPMIN
- ZPOPMAX
- ZRANDMEMBER
- ZRANGE
- ZRANGEBYLEX
- ZRANGEBYSCORE
- ZRANK
- ZREM
- ZREMRANGEBYLEX
- ZREMRANGEBYRANK
- ZREMRANGEBYSCORE
- ZREVRANGE
- ZREVRANGEBYLEX
- ZREVRANGEBYSCORE
- ZREVRANK
- ZSCORE
- ZUNION
- ZUNIONSTORE
- ZSCAN
- Stream keys
- XACK
- XADD
- XAUTOCLAIM
- XCLAIM
- XDEL
- XGROUP CREATE
- XGROUP CREATECONSUMER
- XGROUP DESTROY
- XGROUP DELCONSUMER
- XINFO STREAM -- partly
- XINFO GROUPS
- XINFO CONSUMERS -- partly
- XLEN
- XRANGE
- XREAD
- XREADGROUP
- XREVRANGE
- XPENDING
- XTRIM
- Scripting
- EVAL
- EVALSHA
- SCRIPT LOAD
- SCRIPT EXISTS
- SCRIPT FLUSH
- GEO
- GEOADD
- GEODIST
- ~~GEOHASH~~
- GEOPOS
- GEORADIUS
- GEORADIUS_RO
- GEORADIUSBYMEMBER
- GEORADIUSBYMEMBER_RO
- Cluster
- CLUSTER SLOTS
- CLUSTER KEYSLOT
- CLUSTER NODES
- HyperLogLog (complete)
- PFADD
- PFCOUNT
- PFMERGE
## TTLs, key expiration, and time
Since miniredis is intended to be used in unittests TTLs don't decrease
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
key. It will return 0 when no TTL is set.
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
0 will be removed.
EXPIREAT and PEXPIREAT values will be
converted to a duration. For that you can either set m.SetTime(t) to use that
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
which case time.Now() will be used.
SetTime() also sets the value returned by TIME, which defaults to time.Now().
It is not updated by FastForward, only by SetTime.
## Randomness and Seed()
Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
use its own RNG based on that seed.
Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
## Example
``` Go
import (
...
"github.com/alicebob/miniredis/v2"
...
)
func TestSomething(t *testing.T) {
s := miniredis.RunT(t)
// Optionally set some keys your code expects:
s.Set("foo", "bar")
s.HSet("some", "other", "key")
// Run your code and see if it behaves.
// An example using the redigo library from "github.com/gomodule/redigo/redis":
c, err := redis.Dial("tcp", s.Addr())
_, err = c.Do("SET", "foo", "bar")
// Optionally check values in redis...
if got, err := s.Get("foo"); err != nil || got != "bar" {
t.Error("'foo' has the wrong value")
}
// ... or use a helper for that:
s.CheckGet(t, "foo", "bar")
// TTL and expiration:
s.Set("foo", "bar")
s.SetTTL("foo", 10*time.Second)
s.FastForward(11 * time.Second)
if s.Exists("foo") {
t.Fatal("'foo' should not have existed anymore")
}
}
```
## Not supported
Commands which will probably not be implemented:
- CLUSTER (all)
- ~~CLUSTER *~~
- ~~READONLY~~
- ~~READWRITE~~
- Key
- ~~DUMP~~
- ~~MIGRATE~~
- ~~OBJECT~~
- ~~RESTORE~~
- ~~WAIT~~
- Scripting
- ~~FCALL / FCALL_RO *~~
- ~~FUNCTION *~~
- ~~SCRIPT DEBUG~~
- ~~SCRIPT KILL~~
- Server
- ~~BGSAVE~~
- ~~BGWRITEAOF~~
- ~~CLIENT *~~
- ~~CONFIG *~~
- ~~DEBUG *~~
- ~~LASTSAVE~~
- ~~MONITOR~~
- ~~ROLE~~
- ~~SAVE~~
- ~~SHUTDOWN~~
- ~~SLAVEOF~~
- ~~SLOWLOG~~
- ~~SYNC~~
## &c.
Integration tests are run against Redis 7.2.4. The [./integration](./integration/) subdir
compares miniredis against a real redis instance.
The Redis 6 RESP3 protocol is supported. If there are problems, please open
an issue.
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
[![Go Reference](https://pkg.go.dev/badge/github.com/alicebob/miniredis/v2.svg)](https://pkg.go.dev/github.com/alicebob/miniredis/v2)
+63
View File
@@ -0,0 +1,63 @@
package miniredis
import (
"reflect"
"sort"
)
// T is implemented by Testing.T
type T interface {
Helper()
Errorf(string, ...interface{})
}
// CheckGet does not call Errorf() iff there is a string key with the
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
func (m *Miniredis) CheckGet(t T, key, expected string) {
t.Helper()
found, err := m.Get(key)
if err != nil {
t.Errorf("GET error, key %#v: %v", key, err)
return
}
if found != expected {
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}
// CheckList does not call Errorf() iff there is a list key with the
// expected values.
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
t.Helper()
found, err := m.List(key)
if err != nil {
t.Errorf("List error, key %#v: %v", key, err)
return
}
if !reflect.DeepEqual(expected, found) {
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}
// CheckSet does not call Errorf() iff there is a set key with the
// expected values.
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
t.Helper()
found, err := m.Members(key)
if err != nil {
t.Errorf("Set error, key %#v: %v", key, err)
return
}
sort.Strings(expected)
if !reflect.DeepEqual(expected, found) {
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
return
}
}
+68
View File
@@ -0,0 +1,68 @@
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsClient handles client operations.
func commandsClient(m *Miniredis) {
m.srv.Register("CLIENT", m.cmdClient)
}
// CLIENT
func (m *Miniredis) cmdClient(c *server.Peer, cmd string, args []string) {
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR wrong number of arguments for 'client' command")
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
switch cmd := strings.ToUpper(args[0]); cmd {
case "SETNAME":
m.cmdClientSetName(c, args[1:])
case "GETNAME":
m.cmdClientGetName(c, args[1:])
default:
setDirty(c)
c.WriteError(fmt.Sprintf("ERR unknown subcommand '%s'. Try CLIENT HELP.", cmd))
}
})
}
// CLIENT SETNAME
func (m *Miniredis) cmdClientSetName(c *server.Peer, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError("ERR wrong number of arguments for 'client setname' command")
return
}
name := args[0]
if strings.ContainsAny(name, " \n") {
setDirty(c)
c.WriteError("ERR Client names cannot contain spaces, newlines or special characters.")
return
}
c.ClientName = name
c.WriteOK()
}
// CLIENT GETNAME
func (m *Miniredis) cmdClientGetName(c *server.Peer, args []string) {
if len(args) > 0 {
setDirty(c)
c.WriteError("ERR wrong number of arguments for 'client getname' command")
return
}
if c.ClientName == "" {
c.WriteNull()
} else {
c.WriteBulk(c.ClientName)
}
}
+67
View File
@@ -0,0 +1,67 @@
// Commands from https://redis.io/commands#cluster
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsCluster handles some cluster operations.
func commandsCluster(m *Miniredis) {
m.srv.Register("CLUSTER", m.cmdCluster)
}
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
switch strings.ToUpper(args[0]) {
case "SLOTS":
m.cmdClusterSlots(c, cmd, args)
case "KEYSLOT":
m.cmdClusterKeySlot(c, cmd, args)
case "NODES":
m.cmdClusterNodes(c, cmd, args)
default:
setDirty(c)
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
return
}
}
// CLUSTER SLOTS
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteLen(1)
c.WriteLen(3)
c.WriteInt(0)
c.WriteInt(16383)
c.WriteLen(3)
c.WriteBulk(m.srv.Addr().IP.String())
c.WriteInt(m.srv.Addr().Port)
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
})
}
// CLUSTER KEYSLOT
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteInt(163)
})
}
// CLUSTER NODES
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
})
}
File diff suppressed because one or more lines are too long
+285
View File
@@ -0,0 +1,285 @@
// Commands from https://redis.io/commands#connection
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
func commandsConnection(m *Miniredis) {
m.srv.Register("AUTH", m.cmdAuth)
m.srv.Register("ECHO", m.cmdEcho)
m.srv.Register("HELLO", m.cmdHello)
m.srv.Register("PING", m.cmdPing)
m.srv.Register("QUIT", m.cmdQuit)
m.srv.Register("SELECT", m.cmdSelect)
m.srv.Register("SWAPDB", m.cmdSwapdb)
}
// PING
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if len(args) > 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
payload := ""
if len(args) > 0 {
payload = args[0]
}
// PING is allowed in subscribed state
if sub := getCtx(c).subscriber; sub != nil {
c.Block(func(c *server.Writer) {
c.WriteLen(2)
c.WriteBulk("pong")
c.WriteBulk(payload)
})
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if payload == "" {
c.WriteInline("PONG")
return
}
c.WriteBulk(payload)
})
}
// AUTH
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if len(args) > 2 {
c.WriteError(msgSyntaxError)
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
var opts = struct {
username string
password string
}{
username: "default",
password: args[0],
}
if len(args) == 2 {
opts.username, opts.password = args[0], args[1]
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if len(m.passwords) == 0 && opts.username == "default" {
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
return
}
setPW, ok := m.passwords[opts.username]
if !ok {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
if setPW != opts.password {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
ctx.authenticated = true
c.WriteOK()
})
}
// HELLO
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
c.WriteError(errWrongNumber(cmd))
return
}
var opts struct {
version int
username string
password string
}
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
return
}
args = args[1:]
switch opts.version {
case 2, 3:
default:
c.WriteError("NOPROTO unsupported protocol version")
return
}
var checkAuth bool
for len(args) > 0 {
switch strings.ToUpper(args[0]) {
case "AUTH":
if len(args) < 3 {
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
opts.username, opts.password, args = args[1], args[2], args[3:]
checkAuth = true
case "SETNAME":
if len(args) < 2 {
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
_, args = args[1], args[2:]
default:
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
return
}
}
if len(m.passwords) == 0 && opts.username == "default" {
// redis ignores legacy "AUTH" if it's not enabled.
checkAuth = false
}
if checkAuth {
setPW, ok := m.passwords[opts.username]
if !ok {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
if setPW != opts.password {
c.WriteError("WRONGPASS invalid username-password pair")
return
}
getCtx(c).authenticated = true
}
c.Resp3 = opts.version == 3
c.WriteMapLen(7)
c.WriteBulk("server")
c.WriteBulk("miniredis")
c.WriteBulk("version")
c.WriteBulk("6.0.5")
c.WriteBulk("proto")
c.WriteInt(opts.version)
c.WriteBulk("id")
c.WriteInt(42)
c.WriteBulk("mode")
c.WriteBulk("standalone")
c.WriteBulk("role")
c.WriteBulk("master")
c.WriteBulk("modules")
c.WriteLen(0)
}
// ECHO
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
msg := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteBulk(msg)
})
}
// SELECT
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.isValidCMD(c, cmd) {
return
}
var opts struct {
id int
}
if ok := optInt(c, args[0], &opts.id); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if opts.id < 0 {
c.WriteError(msgDBIndexOutOfRange)
setDirty(c)
return
}
ctx.selectedDB = opts.id
c.WriteOK()
})
}
// SWAPDB
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
var opts struct {
id1 int
id2 int
}
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
return
}
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if opts.id1 < 0 || opts.id2 < 0 {
c.WriteError(msgDBIndexOutOfRange)
setDirty(c)
return
}
m.swapDB(opts.id1, opts.id2)
c.WriteOK()
})
}
// QUIT
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
// QUIT isn't transactionfied and accepts any arguments.
c.WriteOK()
c.Close()
}
+813
View File
@@ -0,0 +1,813 @@
// Commands from https://redis.io/commands#generic
package miniredis
import (
"errors"
"fmt"
"sort"
"strconv"
"strings"
"time"
"github.com/alicebob/miniredis/v2/server"
)
const (
// expiretimeReplyNoExpiration is return value for EXPIRETIME and PEXPIRETIME if the key exists but has no associated expiration time
expiretimeReplyNoExpiration = -1
// expiretimeReplyMissingKey is return value for EXPIRETIME and PEXPIRETIME if the key does not exist
expiretimeReplyMissingKey = -2
)
func inSeconds(t time.Time) int {
return int(t.Unix())
}
func inMilliSeconds(t time.Time) int {
return int(t.UnixMilli())
}
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
func commandsGeneric(m *Miniredis) {
m.srv.Register("COPY", m.cmdCopy)
m.srv.Register("DEL", m.cmdDel)
// DUMP
m.srv.Register("EXISTS", m.cmdExists)
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
m.srv.Register("EXPIRETIME", m.makeCmdExpireTime(inSeconds))
m.srv.Register("PEXPIRETIME", m.makeCmdExpireTime(inMilliSeconds))
m.srv.Register("KEYS", m.cmdKeys)
// MIGRATE
m.srv.Register("MOVE", m.cmdMove)
// OBJECT
m.srv.Register("PERSIST", m.cmdPersist)
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
m.srv.Register("PTTL", m.cmdPTTL)
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
m.srv.Register("RENAME", m.cmdRename)
m.srv.Register("RENAMENX", m.cmdRenamenx)
// RESTORE
m.srv.Register("TOUCH", m.cmdTouch)
m.srv.Register("TTL", m.cmdTTL)
m.srv.Register("TYPE", m.cmdType)
m.srv.Register("SCAN", m.cmdScan)
// SORT
m.srv.Register("UNLINK", m.cmdDel)
}
type expireOpts struct {
key string
value int
nx bool
xx bool
gt bool
lt bool
}
func expireParse(cmd string, args []string) (*expireOpts, error) {
var opts expireOpts
opts.key = args[0]
if err := optIntSimple(args[1], &opts.value); err != nil {
return nil, err
}
args = args[2:]
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "nx":
opts.nx = true
case "xx":
opts.xx = true
case "gt":
opts.gt = true
case "lt":
opts.lt = true
default:
return nil, fmt.Errorf("ERR Unsupported option %s", args[0])
}
args = args[1:]
}
if opts.gt && opts.lt {
return nil, errors.New("ERR GT and LT options at the same time are not compatible")
}
if opts.nx && (opts.xx || opts.gt || opts.lt) {
return nil, errors.New("ERR NX and XX, GT or LT options at the same time are not compatible")
}
return &opts, nil
}
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
// converted to a duration.
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
return func(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts, err := expireParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError(err.Error())
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// Key must be present.
if _, ok := db.keys[opts.key]; !ok {
c.WriteInt(0)
return
}
oldTTL, ok := db.ttl[opts.key]
var newTTL time.Duration
if unix {
newTTL = m.at(opts.value, d)
} else {
newTTL = time.Duration(opts.value) * d
}
// > NX -- Set expiry only when the key has no expiry
if opts.nx && ok {
c.WriteInt(0)
return
}
// > XX -- Set expiry only when the key has an existing expiry
if opts.xx && !ok {
c.WriteInt(0)
return
}
// > GT -- Set expiry only when the new expiry is greater than current one
// (no exp == infinity)
if opts.gt && (!ok || newTTL <= oldTTL) {
c.WriteInt(0)
return
}
// > LT -- Set expiry only when the new expiry is less than current one
if opts.lt && ok && newTTL > oldTTL {
c.WriteInt(0)
return
}
db.ttl[opts.key] = newTTL
db.incr(opts.key)
db.checkTTL(opts.key)
c.WriteInt(1)
})
}
}
// makeCmdExpireTime creates server command function that returns the absolute Unix timestamp (since January 1, 1970)
// at which the given key will expire, in unit selected by time result strategy (e.g. seconds, milliseconds).
// For more information see redis documentation for [expiretime] and [pexpiretime].
//
// [expiretime]: https://redis.io/commands/expiretime/
// [pexpiretime]: https://redis.io/commands/pexpiretime/
func (m *Miniredis) makeCmdExpireTime(timeResultStrategy func(time.Time) int) server.Cmd {
return func(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
c.WriteInt(expiretimeReplyMissingKey)
return
}
ttl, ok := db.ttl[key]
if !ok {
c.WriteInt(expiretimeReplyNoExpiration)
return
}
c.WriteInt(timeResultStrategy(m.effectiveNow().Add(ttl)))
})
}
}
// TOUCH
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count := 0
for _, key := range args {
if db.exists(key) {
count++
}
}
c.WriteInt(count)
})
}
// TTL
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// No such key
c.WriteInt(-2)
return
}
v, ok := db.ttl[key]
if !ok {
// no expire value
c.WriteInt(-1)
return
}
c.WriteInt(int(v.Seconds()))
})
}
// PTTL
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// no such key
c.WriteInt(-2)
return
}
v, ok := db.ttl[key]
if !ok {
// no expire value
c.WriteInt(-1)
return
}
c.WriteInt(int(v.Nanoseconds() / 1000000))
})
}
// PERSIST
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if _, ok := db.keys[key]; !ok {
// no such key
c.WriteInt(0)
return
}
if _, ok := db.ttl[key]; !ok {
// no expire value
c.WriteInt(0)
return
}
delete(db.ttl, key)
db.incr(key)
c.WriteInt(1)
})
}
// DEL and UNLINK
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count := 0
for _, key := range args {
if db.exists(key) {
count++
}
db.del(key, true) // delete expire
}
c.WriteInt(count)
})
}
// TYPE
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError("usage error")
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteInline("none")
return
}
c.WriteInline(t)
})
}
// EXISTS
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
found := 0
for _, k := range args {
if db.exists(k) {
found++
}
}
c.WriteInt(found)
})
}
// MOVE
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
targetDB int
}
opts.key = args[0]
opts.targetDB, _ = strconv.Atoi(args[1])
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if ctx.selectedDB == opts.targetDB {
c.WriteError("ERR source and destination objects are the same")
return
}
db := m.db(ctx.selectedDB)
targetDB := m.db(opts.targetDB)
if !db.move(opts.key, targetDB) {
c.WriteInt(0)
return
}
c.WriteInt(1)
})
}
// KEYS
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
keys, _ := matchKeys(db.allKeys(), key)
c.WriteLen(len(keys))
for _, s := range keys {
c.WriteBulk(s)
}
})
}
// RANDOMKEY
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if len(db.keys) == 0 {
c.WriteNull()
return
}
nr := m.randIntn(len(db.keys))
for k := range db.keys {
if nr == 0 {
c.WriteBulk(k)
return
}
nr--
}
})
}
// RENAME
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
from string
to string
}{
from: args[0],
to: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.from) {
c.WriteError(msgKeyNotFound)
return
}
db.rename(opts.from, opts.to)
c.WriteOK()
})
}
// RENAMENX
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
from string
to string
}{
from: args[0],
to: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.from) {
c.WriteError(msgKeyNotFound)
return
}
if db.exists(opts.to) {
c.WriteInt(0)
return
}
db.rename(opts.from, opts.to)
c.WriteInt(1)
})
}
type scanOpts struct {
cursor int
count int
withMatch bool
match string
withType bool
_type string
}
func scanParse(cmd string, args []string) (*scanOpts, error) {
var opts scanOpts
if err := optIntSimple(args[0], &opts.cursor); err != nil {
return nil, errors.New(msgInvalidCursor)
}
args = args[1:]
// MATCH, COUNT and TYPE options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
if len(args) < 2 {
return nil, errors.New(msgSyntaxError)
}
count, err := strconv.Atoi(args[1])
if err != nil || count < 0 {
return nil, errors.New(msgInvalidInt)
}
if count == 0 {
return nil, errors.New(msgSyntaxError)
}
opts.count = count
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
return nil, errors.New(msgSyntaxError)
}
opts.withMatch = true
opts.match, args = args[1], args[2:]
continue
}
if strings.ToLower(args[0]) == "type" {
if len(args) < 2 {
return nil, errors.New(msgSyntaxError)
}
opts.withType = true
opts._type, args = strings.ToLower(args[1]), args[2:]
continue
}
return nil, errors.New(msgSyntaxError)
}
return &opts, nil
}
// SCAN
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts, err := scanParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError(err.Error())
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// We return _all_ (matched) keys every time.
var keys []string
if opts.withType {
keys = make([]string, 0)
for k, t := range db.keys {
// type must be given exactly; no pattern matching is performed
if t == opts._type {
keys = append(keys, k)
}
}
} else {
keys = db.allKeys()
}
sort.Strings(keys) // To make things deterministic.
if opts.withMatch {
keys, _ = matchKeys(keys, opts.match)
}
low := opts.cursor
high := low + opts.count
// validate high is correct
if high > len(keys) || high == 0 {
high = len(keys)
}
if opts.cursor > high {
// invalid cursor
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(0) // no elements
return
}
cursorValue := low + opts.count
if cursorValue >= len(keys) {
cursorValue = 0 // no next cursor
}
keys = keys[low:high]
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%d", cursorValue))
c.WriteLen(len(keys))
for _, k := range keys {
c.WriteBulk(k)
}
})
}
type copyOpts struct {
from string
to string
destinationDB int
replace bool
}
func copyParse(cmd string, args []string) (*copyOpts, error) {
opts := copyOpts{
destinationDB: -1,
}
opts.from, opts.to, args = args[0], args[1], args[2:]
for len(args) > 0 {
switch strings.ToLower(args[0]) {
case "db":
if len(args) < 2 {
return nil, errors.New(msgSyntaxError)
}
if err := optIntSimple(args[1], &opts.destinationDB); err != nil {
return nil, err
}
if opts.destinationDB < 0 {
return nil, errors.New(msgDBIndexOutOfRange)
}
args = args[2:]
case "replace":
opts.replace = true
args = args[1:]
default:
return nil, errors.New(msgSyntaxError)
}
}
return &opts, nil
}
// COPY
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts, err := copyParse(cmd, args)
if err != nil {
setDirty(c)
c.WriteError(err.Error())
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
fromDB, toDB := ctx.selectedDB, opts.destinationDB
if toDB == -1 {
toDB = fromDB
}
if fromDB == toDB && opts.from == opts.to {
c.WriteError("ERR source and destination objects are the same")
return
}
if !m.db(fromDB).exists(opts.from) {
c.WriteInt(0)
return
}
if !opts.replace {
if m.db(toDB).exists(opts.to) {
c.WriteInt(0)
return
}
}
m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to)
c.WriteInt(1)
})
}
+609
View File
@@ -0,0 +1,609 @@
// Commands from https://redis.io/commands#geo
package miniredis
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsGeo handles GEOADD, GEORADIUS etc.
func commandsGeo(m *Miniredis) {
m.srv.Register("GEOADD", m.cmdGeoadd)
m.srv.Register("GEODIST", m.cmdGeodist)
m.srv.Register("GEOPOS", m.cmdGeopos)
m.srv.Register("GEORADIUS", m.cmdGeoradius)
m.srv.Register("GEORADIUS_RO", m.cmdGeoradius)
m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember)
m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember)
}
// GEOADD
func (m *Miniredis) cmdGeoadd(c *server.Peer, cmd string, args []string) {
if len(args) < 3 || len(args[1:])%3 != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != keyTypeSortedSet {
c.WriteError(ErrWrongType.Error())
return
}
toSet := map[string]float64{}
for len(args) > 2 {
rawLong, rawLat, name := args[0], args[1], args[2]
args = args[3:]
longitude, err := strconv.ParseFloat(rawLong, 64)
if err != nil {
c.WriteError("ERR value is not a valid float")
return
}
latitude, err := strconv.ParseFloat(rawLat, 64)
if err != nil {
c.WriteError("ERR value is not a valid float")
return
}
if latitude < -85.05112878 ||
latitude > 85.05112878 ||
longitude < -180 ||
longitude > 180 {
c.WriteError(fmt.Sprintf("ERR invalid longitude,latitude pair %.6f,%.6f", longitude, latitude))
return
}
toSet[name] = float64(toGeohash(longitude, latitude))
}
set := 0
for name, score := range toSet {
if db.ssetAdd(key, score, name) {
set++
}
}
c.WriteInt(set)
})
}
// GEODIST
func (m *Miniredis) cmdGeodist(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, from, to, args := args[0], args[1], args[2], args[3:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteNull()
return
}
if db.t(key) != keyTypeSortedSet {
c.WriteError(ErrWrongType.Error())
return
}
unit := "m"
if len(args) > 0 {
unit, args = args[0], args[1:]
}
if len(args) > 0 {
c.WriteError(msgSyntaxError)
return
}
toMeter := parseUnit(unit)
if toMeter == 0 {
c.WriteError(msgUnsupportedUnit)
return
}
members := db.sortedsetKeys[key]
fromD, okFrom := members.get(from)
toD, okTo := members.get(to)
if !okFrom || !okTo {
c.WriteNull()
return
}
fromLo, fromLat := fromGeohash(uint64(fromD))
toLo, toLat := fromGeohash(uint64(toD))
dist := distance(fromLat, fromLo, toLat, toLo) / toMeter
c.WriteBulk(fmt.Sprintf("%.4f", dist))
})
}
// GEOPOS
func (m *Miniredis) cmdGeopos(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != keyTypeSortedSet {
c.WriteError(ErrWrongType.Error())
return
}
c.WriteLen(len(args))
for _, l := range args {
if !db.ssetExists(key, l) {
c.WriteLen(-1)
continue
}
score := db.ssetScore(key, l)
c.WriteLen(2)
long, lat := fromGeohash(uint64(score))
c.WriteBulk(fmt.Sprintf("%f", long))
c.WriteBulk(fmt.Sprintf("%f", lat))
}
})
}
type geoDistance struct {
Name string
Score float64
Distance float64
Longitude float64
Latitude float64
}
// GEORADIUS and GEORADIUS_RO
func (m *Miniredis) cmdGeoradius(c *server.Peer, cmd string, args []string) {
if len(args) < 5 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
longitude, err := strconv.ParseFloat(args[1], 64)
if err != nil {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
latitude, err := strconv.ParseFloat(args[2], 64)
if err != nil {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
radius, err := strconv.ParseFloat(args[3], 64)
if err != nil || radius < 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
toMeter := parseUnit(args[4])
if toMeter == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
args = args[5:]
var opts struct {
withDist bool
withCoord bool
direction direction // unsorted
count int
withStore bool
storeKey string
withStoredist bool
storedistKey string
}
for len(args) > 0 {
arg := args[0]
args = args[1:]
switch strings.ToUpper(arg) {
case "WITHCOORD":
opts.withCoord = true
case "WITHDIST":
opts.withDist = true
case "ASC":
opts.direction = asc
case "DESC":
opts.direction = desc
case "COUNT":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
n, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if n <= 0 {
setDirty(c)
c.WriteError("ERR COUNT must be > 0")
return
}
args = args[1:]
opts.count = n
case "STORE":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStore = true
opts.storeKey = args[0]
args = args[1:]
case "STOREDIST":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStoredist = true
opts.storedistKey = args[0]
args = args[1:]
default:
setDirty(c)
c.WriteError("ERR syntax error")
return
}
}
if strings.ToUpper(cmd) == "GEORADIUS_RO" && (opts.withStore || opts.withStoredist) {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
return
}
db := m.db(ctx.selectedDB)
members := db.ssetElements(key)
matches := withinRadius(members, longitude, latitude, radius*toMeter)
// deal with ASC/DESC
if opts.direction != unsorted {
sort.Slice(matches, func(i, j int) bool {
if opts.direction == desc {
return matches[i].Distance > matches[j].Distance
}
return matches[i].Distance < matches[j].Distance
})
}
// deal with COUNT
if opts.count > 0 && len(matches) > opts.count {
matches = matches[:opts.count]
}
// deal with "STORE x"
if opts.withStore {
db.del(opts.storeKey, true)
for _, member := range matches {
db.ssetAdd(opts.storeKey, member.Score, member.Name)
}
c.WriteInt(len(matches))
return
}
// deal with "STOREDIST x"
if opts.withStoredist {
db.del(opts.storedistKey, true)
for _, member := range matches {
db.ssetAdd(opts.storedistKey, member.Distance/toMeter, member.Name)
}
c.WriteInt(len(matches))
return
}
c.WriteLen(len(matches))
for _, member := range matches {
if !opts.withDist && !opts.withCoord {
c.WriteBulk(member.Name)
continue
}
len := 1
if opts.withDist {
len++
}
if opts.withCoord {
len++
}
c.WriteLen(len)
c.WriteBulk(member.Name)
if opts.withDist {
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter))
}
if opts.withCoord {
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
}
}
})
}
// GEORADIUSBYMEMBER and GEORADIUSBYMEMBER_RO
func (m *Miniredis) cmdGeoradiusbymember(c *server.Peer, cmd string, args []string) {
if len(args) < 4 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
member string
radius float64
toMeter float64
withDist bool
withCoord bool
direction direction // unsorted
count int
withStore bool
storeKey string
withStoredist bool
storedistKey string
}{
key: args[0],
member: args[1],
}
r, err := strconv.ParseFloat(args[2], 64)
if err != nil || r < 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
opts.radius = r
opts.toMeter = parseUnit(args[3])
if opts.toMeter == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
args = args[4:]
for len(args) > 0 {
arg := args[0]
args = args[1:]
switch strings.ToUpper(arg) {
case "WITHCOORD":
opts.withCoord = true
case "WITHDIST":
opts.withDist = true
case "ASC":
opts.direction = asc
case "DESC":
opts.direction = desc
case "COUNT":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
n, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if n <= 0 {
setDirty(c)
c.WriteError("ERR COUNT must be > 0")
return
}
args = args[1:]
opts.count = n
case "STORE":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStore = true
opts.storeKey = args[0]
args = args[1:]
case "STOREDIST":
if len(args) == 0 {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
opts.withStoredist = true
opts.storedistKey = args[0]
args = args[1:]
default:
setDirty(c)
c.WriteError("ERR syntax error")
return
}
}
if strings.ToUpper(cmd) == "GEORADIUSBYMEMBER_RO" && (opts.withStore || opts.withStoredist) {
setDirty(c)
c.WriteError("ERR syntax error")
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
return
}
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
c.WriteNull()
return
}
if db.t(opts.key) != keyTypeSortedSet {
c.WriteError(ErrWrongType.Error())
return
}
// get position of member
if !db.ssetExists(opts.key, opts.member) {
c.WriteError("ERR could not decode requested zset member")
return
}
score := db.ssetScore(opts.key, opts.member)
longitude, latitude := fromGeohash(uint64(score))
members := db.ssetElements(opts.key)
matches := withinRadius(members, longitude, latitude, opts.radius*opts.toMeter)
// deal with ASC/DESC
if opts.direction != unsorted {
sort.Slice(matches, func(i, j int) bool {
if opts.direction == desc {
return matches[i].Distance > matches[j].Distance
}
return matches[i].Distance < matches[j].Distance
})
}
// deal with COUNT
if opts.count > 0 && len(matches) > opts.count {
matches = matches[:opts.count]
}
// deal with "STORE x"
if opts.withStore {
db.del(opts.storeKey, true)
for _, member := range matches {
db.ssetAdd(opts.storeKey, member.Score, member.Name)
}
c.WriteInt(len(matches))
return
}
// deal with "STOREDIST x"
if opts.withStoredist {
db.del(opts.storedistKey, true)
for _, member := range matches {
db.ssetAdd(opts.storedistKey, member.Distance/opts.toMeter, member.Name)
}
c.WriteInt(len(matches))
return
}
c.WriteLen(len(matches))
for _, member := range matches {
if !opts.withDist && !opts.withCoord {
c.WriteBulk(member.Name)
continue
}
len := 1
if opts.withDist {
len++
}
if opts.withCoord {
len++
}
c.WriteLen(len)
c.WriteBulk(member.Name)
if opts.withDist {
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/opts.toMeter))
}
if opts.withCoord {
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
}
}
})
}
func withinRadius(members []ssElem, longitude, latitude, radius float64) []geoDistance {
matches := []geoDistance{}
for _, el := range members {
elLo, elLat := fromGeohash(uint64(el.score))
distanceInMeter := distance(latitude, longitude, elLat, elLo)
if distanceInMeter <= radius {
matches = append(matches, geoDistance{
Name: el.member,
Score: el.score,
Distance: distanceInMeter,
Longitude: elLo,
Latitude: elLat,
})
}
}
return matches
}
func parseUnit(u string) float64 {
switch strings.ToLower(u) {
case "m":
return 1
case "km":
return 1000
case "mi":
return 1609.34
case "ft":
return 0.3048
default:
return 0
}
}
+777
View File
@@ -0,0 +1,777 @@
// Commands from https://redis.io/commands#hash
package miniredis
import (
"math/big"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsHash handles all hash value operations.
func commandsHash(m *Miniredis) {
m.srv.Register("HDEL", m.cmdHdel)
m.srv.Register("HEXISTS", m.cmdHexists)
m.srv.Register("HGET", m.cmdHget)
m.srv.Register("HGETALL", m.cmdHgetall)
m.srv.Register("HINCRBY", m.cmdHincrby)
m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat)
m.srv.Register("HKEYS", m.cmdHkeys)
m.srv.Register("HLEN", m.cmdHlen)
m.srv.Register("HMGET", m.cmdHmget)
m.srv.Register("HMSET", m.cmdHmset)
m.srv.Register("HSET", m.cmdHset)
m.srv.Register("HSETNX", m.cmdHsetnx)
m.srv.Register("HSTRLEN", m.cmdHstrlen)
m.srv.Register("HVALS", m.cmdHvals)
m.srv.Register("HSCAN", m.cmdHscan)
m.srv.Register("HRANDFIELD", m.cmdHrandfield)
}
// HSET
func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, pairs := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if len(pairs)%2 == 1 {
c.WriteError(errWrongNumber(cmd))
return
}
if t, ok := db.keys[key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
new := db.hashSet(key, pairs...)
c.WriteInt(new)
})
}
// HSETNX
func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
value string
}{
key: args[0],
field: args[1],
value: args[2],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
if _, ok := db.hashKeys[opts.key]; !ok {
db.hashKeys[opts.key] = map[string]string{}
db.keys[opts.key] = keyTypeHash
}
_, ok := db.hashKeys[opts.key][opts.field]
if ok {
c.WriteInt(0)
return
}
db.hashKeys[opts.key][opts.field] = opts.value
db.incr(opts.key)
c.WriteInt(1)
})
}
// HMSET
func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) {
if len(args) < 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, args := args[0], args[1:]
if len(args)%2 != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
for len(args) > 0 {
field, value := args[0], args[1]
args = args[2:]
db.hashSet(key, field, value)
}
c.WriteOK()
})
}
// HGET
func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, field := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteNull()
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
value, ok := db.hashKeys[key][field]
if !ok {
c.WriteNull()
return
}
c.WriteBulk(value)
})
}
// HDEL
func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
fields []string
}{
key: args[0],
fields: args[1:],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[opts.key]
if !ok {
// No key is zero deleted
c.WriteInt(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
deleted := 0
for _, f := range opts.fields {
_, ok := db.hashKeys[opts.key][f]
if !ok {
continue
}
delete(db.hashKeys[opts.key], f)
deleted++
}
c.WriteInt(deleted)
// Nothing left. Remove the whole key.
if len(db.hashKeys[opts.key]) == 0 {
db.del(opts.key, true)
}
})
}
// HEXISTS
func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
}{
key: args[0],
field: args[1],
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[opts.key]
if !ok {
c.WriteInt(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
if _, ok := db.hashKeys[opts.key][opts.field]; !ok {
c.WriteInt(0)
return
}
c.WriteInt(1)
})
}
// HGETALL
func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteMapLen(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
c.WriteMapLen(len(db.hashKeys[key]))
for _, k := range db.hashFields(key) {
c.WriteBulk(k)
c.WriteBulk(db.hashGet(key, k))
}
})
}
// HKEYS
func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteLen(0)
return
}
if db.t(key) != keyTypeHash {
c.WriteError(msgWrongType)
return
}
fields := db.hashFields(key)
c.WriteLen(len(fields))
for _, f := range fields {
c.WriteBulk(f)
}
})
}
// HSTRLEN
func (m *Miniredis) cmdHstrlen(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
hash, key := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[hash]
if !ok {
c.WriteInt(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
keys := db.hashKeys[hash]
c.WriteInt(len(keys[key]))
})
}
// HVALS
func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteLen(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
vals := db.hashValues(key)
c.WriteLen(len(vals))
for _, v := range vals {
c.WriteBulk(v)
}
})
}
// HLEN
func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.keys[key]
if !ok {
c.WriteInt(0)
return
}
if t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
c.WriteInt(len(db.hashKeys[key]))
})
}
// HMGET
func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
f, ok := db.hashKeys[key]
if !ok {
f = map[string]string{}
}
c.WriteLen(len(args) - 1)
for _, k := range args[1:] {
v, ok := f[k]
if !ok {
c.WriteNull()
continue
}
c.WriteBulk(v)
}
})
}
// HINCRBY
func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
delta int
}{
key: args[0],
field: args[1],
}
if ok := optInt(c, args[2], &opts.delta); !ok {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
v, err := db.hashIncr(opts.key, opts.field, opts.delta)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteInt(v)
})
}
// HINCRBYFLOAT
func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
field string
delta *big.Float
}{
key: args[0],
field: args[1],
}
delta, _, err := big.ParseFloat(args[2], 10, 128, 0)
if err != nil {
setDirty(c)
c.WriteError(msgInvalidFloat)
return
}
opts.delta = delta
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
c.WriteError(msgWrongType)
return
}
v, err := db.hashIncrfloat(opts.key, opts.field, opts.delta)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteBulk(formatBig(v))
})
}
// HSCAN
func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
cursor int
withMatch bool
match string
}{
key: args[0],
}
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
return
}
args = args[2:]
// MATCH and COUNT options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
// we do nothing with count
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
_, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.withMatch = true
opts.match, args = args[1], args[2:]
continue
}
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// return _all_ (matched) keys every time
if opts.cursor != 0 {
// Invalid cursor.
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(0) // no elements
return
}
if db.exists(opts.key) && db.t(opts.key) != keyTypeHash {
c.WriteError(ErrWrongType.Error())
return
}
members := db.hashFields(opts.key)
if opts.withMatch {
members, _ = matchKeys(members, opts.match)
}
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
// HSCAN gives key, values.
c.WriteLen(len(members) * 2)
for _, k := range members {
c.WriteBulk(k)
c.WriteBulk(db.hashGet(opts.key, k))
}
})
}
// HRANDFIELD
func (m *Miniredis) cmdHrandfield(c *server.Peer, cmd string, args []string) {
if len(args) > 3 || len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
count int
countSet bool
withValues bool
}{
key: args[0],
}
if len(args) > 1 {
if ok := optIntErr(c, args[1], &opts.count, msgInvalidInt); !ok {
return
}
opts.countSet = true
}
if len(args) == 3 {
if strings.ToLower(args[2]) == "withvalues" {
opts.withValues = true
} else {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
}
withTx(m, c, func(peer *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
members := db.hashFields(opts.key)
m.shuffle(members)
if !opts.countSet {
// > When called with just the key argument, return a random field from the
// hash value stored at key.
if len(members) == 0 {
peer.WriteNull()
return
}
peer.WriteBulk(members[0])
return
}
if len(members) > abs(opts.count) {
members = members[:abs(opts.count)]
}
switch {
case opts.count >= 0:
// if count is positive there can't be duplicates, and the length is restricted
case opts.count < 0:
// if count is negative there can be duplicates, but length will match
if len(members) > 0 {
for len(members) < -opts.count {
members = append(members, members[m.randIntn(len(members))])
}
}
}
if opts.withValues {
peer.WriteMapLen(len(members))
for _, m := range members {
peer.WriteBulk(m)
peer.WriteBulk(db.hashGet(opts.key, m))
}
return
}
peer.WriteLen(len(members))
for _, m := range members {
peer.WriteBulk(m)
}
})
}
func abs(n int) int {
if n < 0 {
return -n
}
return n
}
+95
View File
@@ -0,0 +1,95 @@
package miniredis
import "github.com/alicebob/miniredis/v2/server"
// commandsHll handles all hll related operations.
func commandsHll(m *Miniredis) {
m.srv.Register("PFADD", m.cmdPfadd)
m.srv.Register("PFCOUNT", m.cmdPfcount)
m.srv.Register("PFMERGE", m.cmdPfmerge)
}
// PFADD
func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, items := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != keyTypeHll {
c.WriteError(ErrNotValidHllValue.Error())
return
}
altered := db.hllAdd(key, items...)
c.WriteInt(altered)
})
}
// PFCOUNT
func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count, err := db.hllCount(keys)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteInt(count)
})
}
// PFMERGE
func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if err := db.hllMerge(keys); err != nil {
c.WriteError(err.Error())
return
}
c.WriteOK()
})
}
+40
View File
@@ -0,0 +1,40 @@
package miniredis
import (
"fmt"
"github.com/alicebob/miniredis/v2/server"
)
// Command 'INFO' from https://redis.io/commands/info/
func (m *Miniredis) cmdInfo(c *server.Peer, cmd string, args []string) {
if !m.isValidCMD(c, cmd) {
return
}
if len(args) > 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
const (
clientsSectionName = "clients"
clientsSectionContent = "# Clients\nconnected_clients:%d\r\n"
)
var result string
for _, key := range args {
if key != clientsSectionName {
setDirty(c)
c.WriteError(fmt.Sprintf("section (%s) is not supported", key))
return
}
}
result = fmt.Sprintf(clientsSectionContent, m.Server().ClientsLen())
c.WriteBulk(result)
})
}
File diff suppressed because it is too large Load Diff
+58
View File
@@ -0,0 +1,58 @@
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsObject handles all object operations.
func commandsObject(m *Miniredis) {
m.srv.Register("OBJECT", m.cmdObject)
}
// OBJECT
func (m *Miniredis) cmdObject(c *server.Peer, cmd string, args []string) {
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
switch sub := strings.ToLower(args[0]); sub {
case "idletime":
m.cmdObjectIdletime(c, args[1:])
default:
setDirty(c)
c.WriteError(fmt.Sprintf(msgFObjectUsage, sub))
}
}
// OBJECT IDLETIME
func (m *Miniredis) cmdObjectIdletime(c *server.Peer, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber("object|idletime"))
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
t, ok := db.lru[key]
if !ok {
c.WriteNull()
return
}
c.WriteInt(int(db.master.effectiveNow().Sub(t).Seconds()))
})
}
+262
View File
@@ -0,0 +1,262 @@
// Commands from https://redis.io/commands#pubsub
package miniredis
import (
"fmt"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsPubsub handles all PUB/SUB operations.
func commandsPubsub(m *Miniredis) {
m.srv.Register("SUBSCRIBE", m.cmdSubscribe)
m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe)
m.srv.Register("PSUBSCRIBE", m.cmdPsubscribe)
m.srv.Register("PUNSUBSCRIBE", m.cmdPunsubscribe)
m.srv.Register("PUBLISH", m.cmdPublish)
m.srv.Register("PUBSUB", m.cmdPubSub)
}
// SUBSCRIBE
func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
for _, channel := range args {
n := sub.Subscribe(channel)
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("subscribe")
w.WriteBulk(channel)
w.WriteInt(n)
})
}
})
}
// UNSUBSCRIBE
func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
channels := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
if len(channels) == 0 {
channels = sub.Channels()
}
// there is no de-duplication
for _, channel := range channels {
n := sub.Unsubscribe(channel)
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("unsubscribe")
w.WriteBulk(channel)
w.WriteInt(n)
})
}
if len(channels) == 0 {
// special case: there is always a reply
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("unsubscribe")
w.WriteNull()
w.WriteInt(0)
})
}
if sub.Count() == 0 {
endSubscriber(m, c)
}
})
}
// PSUBSCRIBE
func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
for _, pat := range args {
n := sub.Psubscribe(pat)
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("psubscribe")
w.WriteBulk(pat)
w.WriteInt(n)
})
}
})
}
// PUNSUBSCRIBE
func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) {
if !m.handleAuth(c) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
patterns := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sub := m.subscribedState(c)
if len(patterns) == 0 {
patterns = sub.Patterns()
}
// there is no de-duplication
for _, pat := range patterns {
n := sub.Punsubscribe(pat)
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("punsubscribe")
w.WriteBulk(pat)
w.WriteInt(n)
})
}
if len(patterns) == 0 {
// special case: there is always a reply
c.Block(func(w *server.Writer) {
w.WritePushLen(3)
w.WriteBulk("punsubscribe")
w.WriteNull()
w.WriteInt(0)
})
}
if sub.Count() == 0 {
endSubscriber(m, c)
}
})
}
// PUBLISH
func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
channel, mesg := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
c.WriteInt(m.publish(channel, mesg))
})
}
// PUBSUB
func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if m.checkPubsub(c, cmd) {
return
}
subcommand := strings.ToUpper(args[0])
subargs := args[1:]
var argsOk bool
switch subcommand {
case "CHANNELS":
argsOk = len(subargs) < 2
case "NUMSUB":
argsOk = true
case "NUMPAT":
argsOk = len(subargs) == 0
default:
setDirty(c)
c.WriteError(fmt.Sprintf(msgFPubsubUsageSimple, subcommand))
return
}
if !argsOk {
setDirty(c)
c.WriteError(fmt.Sprintf(msgFPubsubUsage, subcommand))
return
}
if !m.handleAuth(c) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
switch subcommand {
case "CHANNELS":
pat := ""
if len(subargs) == 1 {
pat = subargs[0]
}
allsubs := m.allSubscribers()
channels := activeChannels(allsubs, pat)
c.WriteLen(len(channels))
for _, channel := range channels {
c.WriteBulk(channel)
}
case "NUMSUB":
subs := m.allSubscribers()
c.WriteLen(len(subargs) * 2)
for _, channel := range subargs {
c.WriteBulk(channel)
c.WriteInt(countSubs(subs, channel))
}
case "NUMPAT":
c.WriteInt(countPsubs(m.allSubscribers()))
}
})
}
+343
View File
@@ -0,0 +1,343 @@
package miniredis
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"strconv"
"strings"
"sync"
lua "github.com/yuin/gopher-lua"
"github.com/yuin/gopher-lua/parse"
luajson "github.com/alicebob/miniredis/v2/gopher-json"
"github.com/alicebob/miniredis/v2/server"
)
func commandsScripting(m *Miniredis) {
m.srv.Register("EVAL", m.cmdEval)
m.srv.Register("EVALSHA", m.cmdEvalsha)
m.srv.Register("SCRIPT", m.cmdScript)
}
var (
parsedScripts = sync.Map{}
)
// Execute lua. Needs to run m.Lock()ed, from within withTx().
// Returns true if the lua was OK (and hence should be cached).
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
l := lua.NewState(lua.Options{SkipOpenLibs: true})
defer l.Close()
// Taken from the go-lua manual
for _, pair := range []struct {
n string
f lua.LGFunction
}{
{lua.LoadLibName, lua.OpenPackage},
{lua.BaseLibName, lua.OpenBase},
{lua.CoroutineLibName, lua.OpenCoroutine},
{lua.TabLibName, lua.OpenTable},
{lua.StringLibName, lua.OpenString},
{lua.MathLibName, lua.OpenMath},
{lua.DebugLibName, lua.OpenDebug},
} {
if err := l.CallByParam(lua.P{
Fn: l.NewFunction(pair.f),
NRet: 0,
Protect: true,
}, lua.LString(pair.n)); err != nil {
panic(err)
}
}
luajson.Preload(l)
requireGlobal(l, "cjson", "json")
// set global variable KEYS
keysTable := l.NewTable()
keysS, args := args[0], args[1:]
keysLen, err := strconv.Atoi(keysS)
if err != nil {
c.WriteError(msgInvalidInt)
return false
}
if keysLen < 0 {
c.WriteError(msgNegativeKeysNumber)
return false
}
if keysLen > len(args) {
c.WriteError(msgInvalidKeysNumber)
return false
}
keys, args := args[:keysLen], args[keysLen:]
for i, k := range keys {
l.RawSet(keysTable, lua.LNumber(i+1), lua.LString(k))
}
l.SetGlobal("KEYS", keysTable)
argvTable := l.NewTable()
for i, a := range args {
l.RawSet(argvTable, lua.LNumber(i+1), lua.LString(a))
}
l.SetGlobal("ARGV", argvTable)
redisFuncs, redisConstants := mkLua(m.srv, c, sha)
// Register command handlers
l.Push(l.NewFunction(func(l *lua.LState) int {
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
for k, v := range redisConstants {
mod.RawSetString(k, v)
}
l.Push(mod)
return 1
}))
_ = doScript(l, protectGlobals)
l.Push(lua.LString("redis"))
l.Call(1, 0)
// lua can call redis.setresp(...), but it's tmp state.
oldresp := c.Resp3
if err := doScript(l, script); err != nil {
c.WriteError(err.Error())
return false
}
luaToRedis(l, c, l.Get(1))
c.Resp3 = oldresp
c.SwitchResp3 = nil
return true
}
// doScript pre-compiles the given script into a Lua prototype,
// then executes the pre-compiled function against the given lua state.
//
// This is thread-safe.
func doScript(l *lua.LState, script string) error {
proto, err := compile(script)
if err != nil {
return fmt.Errorf(errLuaParseError(err))
}
lfunc := l.NewFunctionFromProto(proto)
l.Push(lfunc)
if err := l.PCall(0, lua.MultRet, nil); err != nil {
// ensure we wrap with the correct format.
return fmt.Errorf(errLuaParseError(err))
}
return nil
}
func compile(script string) (*lua.FunctionProto, error) {
if val, ok := parsedScripts.Load(script); ok {
return val.(*lua.FunctionProto), nil
}
chunk, err := parse.Parse(strings.NewReader(script), "<string>")
if err != nil {
return nil, err
}
proto, err := lua.Compile(chunk, "")
if err != nil {
return nil, err
}
parsedScripts.Store(script, proto)
return proto, nil
}
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
script, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
sha := sha1Hex(script)
ok := m.runLuaScript(c, sha, script, args)
if ok {
m.scripts[sha] = script
}
})
}
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
sha, args := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
script, ok := m.scripts[sha]
if !ok {
c.WriteError(msgNoScriptFound)
return
}
m.runLuaScript(c, sha, script, args)
})
}
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
var opts struct {
subcmd string
script string
}
opts.subcmd, args = args[0], args[1:]
switch strings.ToLower(opts.subcmd) {
case "load":
if len(args) != 1 {
setDirty(c)
c.WriteError(fmt.Sprintf(msgFScriptUsage, "LOAD"))
return
}
opts.script = args[0]
case "exists":
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber("script|exists"))
return
}
case "flush":
if len(args) == 1 {
switch strings.ToUpper(args[0]) {
case "SYNC", "ASYNC":
args = args[1:]
default:
}
}
if len(args) != 0 {
setDirty(c)
c.WriteError(msgScriptFlush)
return
}
default:
setDirty(c)
c.WriteError(fmt.Sprintf(msgFScriptUsageSimple, strings.ToUpper(opts.subcmd)))
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
switch strings.ToLower(opts.subcmd) {
case "load":
if _, err := parse.Parse(strings.NewReader(opts.script), "user_script"); err != nil {
c.WriteError(errLuaParseError(err))
return
}
sha := sha1Hex(opts.script)
m.scripts[sha] = opts.script
c.WriteBulk(sha)
case "exists":
c.WriteLen(len(args))
for _, arg := range args {
if _, ok := m.scripts[arg]; ok {
c.WriteInt(1)
} else {
c.WriteInt(0)
}
}
case "flush":
m.scripts = map[string]string{}
c.WriteOK()
}
})
}
func sha1Hex(s string) string {
h := sha1.New()
io.WriteString(h, s)
return hex.EncodeToString(h.Sum(nil))
}
// requireGlobal imports module modName into the global namespace with the
// identifier id. panics if an error results from the function execution
func requireGlobal(l *lua.LState, id, modName string) {
if err := l.CallByParam(lua.P{
Fn: l.GetGlobal("require"),
NRet: 1,
Protect: true,
}, lua.LString(modName)); err != nil {
panic(err)
}
mod := l.Get(-1)
l.Pop(1)
l.SetGlobal(id, mod)
}
// the following script protects globals
// it is based on: http://metalua.luaforge.net/src/lib/strict.lua.html
var protectGlobals = `
local dbg=debug
local mt = {}
setmetatable(_G, mt)
mt.__newindex = function (t, n, v)
if dbg.getinfo(2) then
local w = dbg.getinfo(2, "S").what
if w ~= "C" then
error("Script attempted to create global variable '"..tostring(n).."'", 2)
end
end
rawset(t, n, v)
end
mt.__index = function (t, n)
if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
end
return rawget(t, n)
end
debug = nil
`
+177
View File
@@ -0,0 +1,177 @@
// Commands from https://redis.io/commands#server
package miniredis
import (
"fmt"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
"github.com/alicebob/miniredis/v2/size"
)
func commandsServer(m *Miniredis) {
m.srv.Register("COMMAND", m.cmdCommand)
m.srv.Register("DBSIZE", m.cmdDbsize)
m.srv.Register("FLUSHALL", m.cmdFlushall)
m.srv.Register("FLUSHDB", m.cmdFlushdb)
m.srv.Register("INFO", m.cmdInfo)
m.srv.Register("TIME", m.cmdTime)
m.srv.Register("MEMORY", m.cmdMemory)
}
// MEMORY
func (m *Miniredis) cmdMemory(c *server.Peer, cmd string, args []string) {
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
cmd, args := strings.ToLower(args[0]), args[1:]
switch cmd {
case "usage":
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber("memory|usage"))
return
}
if len(args) > 1 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
var (
value interface{}
ok bool
)
switch db.keys[args[0]] {
case keyTypeString:
value, ok = db.stringKeys[args[0]]
case keyTypeSet:
value, ok = db.setKeys[args[0]]
case keyTypeHash:
value, ok = db.hashKeys[args[0]]
case keyTypeList:
value, ok = db.listKeys[args[0]]
case keyTypeHll:
value, ok = db.hllKeys[args[0]]
case keyTypeSortedSet:
value, ok = db.sortedsetKeys[args[0]]
case keyTypeStream:
value, ok = db.streamKeys[args[0]]
}
if !ok {
c.WriteNull()
return
}
c.WriteInt(size.Of(value))
default:
c.WriteError(fmt.Sprintf(msgMemorySubcommand, strings.ToUpper(cmd)))
}
})
}
// DBSIZE
func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) {
if len(args) > 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
c.WriteInt(len(db.keys))
})
}
// FLUSHALL
func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) {
if len(args) > 0 && strings.ToLower(args[0]) == "async" {
args = args[1:]
}
if len(args) > 0 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
m.flushAll()
c.WriteOK()
})
}
// FLUSHDB
func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) {
if len(args) > 0 && strings.ToLower(args[0]) == "async" {
args = args[1:]
}
if len(args) > 0 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
m.db(ctx.selectedDB).flush()
c.WriteOK()
})
}
// TIME
func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) {
if len(args) > 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
now := m.effectiveNow()
nanos := now.UnixNano()
seconds := nanos / 1_000_000_000
microseconds := (nanos / 1_000) % 1_000_000
c.WriteLen(2)
c.WriteBulk(strconv.FormatInt(seconds, 10))
c.WriteBulk(strconv.FormatInt(microseconds, 10))
})
}
+836
View File
@@ -0,0 +1,836 @@
// Commands from https://redis.io/commands#set
package miniredis
import (
"fmt"
"strconv"
"strings"
"github.com/alicebob/miniredis/v2/server"
)
// commandsSet handles all set value operations.
func commandsSet(m *Miniredis) {
m.srv.Register("SADD", m.cmdSadd)
m.srv.Register("SCARD", m.cmdScard)
m.srv.Register("SDIFF", m.cmdSdiff)
m.srv.Register("SDIFFSTORE", m.cmdSdiffstore)
m.srv.Register("SINTERCARD", m.cmdSintercard)
m.srv.Register("SINTER", m.cmdSinter)
m.srv.Register("SINTERSTORE", m.cmdSinterstore)
m.srv.Register("SISMEMBER", m.cmdSismember)
m.srv.Register("SMEMBERS", m.cmdSmembers)
m.srv.Register("SMISMEMBER", m.cmdSmismember)
m.srv.Register("SMOVE", m.cmdSmove)
m.srv.Register("SPOP", m.cmdSpop)
m.srv.Register("SRANDMEMBER", m.cmdSrandmember)
m.srv.Register("SREM", m.cmdSrem)
m.srv.Register("SUNION", m.cmdSunion)
m.srv.Register("SUNIONSTORE", m.cmdSunionstore)
m.srv.Register("SSCAN", m.cmdSscan)
}
// SADD
func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, elems := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if db.exists(key) && db.t(key) != keyTypeSet {
c.WriteError(ErrWrongType.Error())
return
}
added := db.setAdd(key, elems...)
c.WriteInt(added)
})
}
// SCARD
func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteInt(0)
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
members := db.setMembers(key)
c.WriteInt(len(members))
})
}
// SDIFF
func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setDiff(keys)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteSetLen(len(set))
for k := range set {
c.WriteBulk(k)
}
})
}
// SDIFFSTORE
func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
dest, keys := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setDiff(keys)
if err != nil {
c.WriteError(err.Error())
return
}
db.del(dest, true)
db.setSet(dest, set)
c.WriteInt(len(set))
})
}
// SINTER
func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setInter(keys)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteLen(len(set))
for k := range set {
c.WriteBulk(k)
}
})
}
// SINTERSTORE
func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
dest, keys := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setInter(keys)
if err != nil {
c.WriteError(err.Error())
return
}
db.del(dest, true)
db.setSet(dest, set)
c.WriteInt(len(set))
})
}
// SINTERCARD
func (m *Miniredis) cmdSintercard(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
keys []string
limit int
}{}
numKeys, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError("ERR numkeys should be greater than 0")
return
}
if numKeys < 1 {
setDirty(c)
c.WriteError("ERR numkeys should be greater than 0")
return
}
args = args[1:]
if len(args) < numKeys {
setDirty(c)
c.WriteError("ERR Number of keys can't be greater than number of args")
return
}
opts.keys = args[:numKeys]
args = args[numKeys:]
if len(args) == 2 && strings.ToLower(args[0]) == "limit" {
l, err := strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if l < 0 {
setDirty(c)
c.WriteError(msgLimitIsNegative)
return
}
opts.limit = l
} else if len(args) > 0 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
count, err := db.setIntercard(opts.keys, opts.limit)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteInt(count)
})
}
// SISMEMBER
func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) {
if len(args) != 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, value := args[0], args[1]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteInt(0)
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
if db.setIsMember(key, value) {
c.WriteInt(1)
return
}
c.WriteInt(0)
})
}
// SMEMBERS
func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) {
if len(args) != 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteSetLen(0)
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
members := db.setMembers(key)
c.WriteSetLen(len(members))
for _, elem := range members {
c.WriteBulk(elem)
}
})
}
// SMISMEMBER
func (m *Miniredis) cmdSmismember(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, values := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteLen(len(values))
for range values {
c.WriteInt(0)
}
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
c.WriteLen(len(values))
for _, value := range values {
if db.setIsMember(key, value) {
c.WriteInt(1)
} else {
c.WriteInt(0)
}
}
return
})
}
// SMOVE
func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) {
if len(args) != 3 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
src, dst, member := args[0], args[1], args[2]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(src) {
c.WriteInt(0)
return
}
if db.t(src) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
if db.exists(dst) && db.t(dst) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
if !db.setIsMember(src, member) {
c.WriteInt(0)
return
}
db.setRem(src, member)
db.setAdd(dst, member)
c.WriteInt(1)
})
}
// SPOP
func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) {
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
opts := struct {
key string
withCount bool
count int
}{
count: 1,
}
opts.key, args = args[0], args[1:]
if len(args) > 0 {
v, err := strconv.Atoi(args[0])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if v < 0 {
setDirty(c)
c.WriteError(msgOutOfRange)
return
}
opts.count = v
opts.withCount = true
args = args[1:]
}
if len(args) > 0 {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(opts.key) {
if !opts.withCount {
c.WriteNull()
return
}
c.WriteLen(0)
return
}
if db.t(opts.key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
var deleted []string
members := db.setMembers(opts.key)
for i := 0; i < opts.count; i++ {
if len(members) == 0 {
break
}
i := m.randIntn(len(members))
member := members[i]
members = delElem(members, i)
db.setRem(opts.key, member)
deleted = append(deleted, member)
}
// without `count` return a single value
if !opts.withCount {
if len(deleted) == 0 {
c.WriteNull()
return
}
c.WriteBulk(deleted[0])
return
}
// with `count` return a list
c.WriteLen(len(deleted))
for _, v := range deleted {
c.WriteBulk(v)
}
})
}
// SRANDMEMBER
func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if len(args) > 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key := args[0]
count := 0
withCount := false
if len(args) == 2 {
var err error
count, err = strconv.Atoi(args[1])
if err != nil {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
withCount = true
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
if withCount {
c.WriteLen(0)
return
}
c.WriteNull()
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
members := db.setMembers(key)
if count < 0 {
// Non-unique elements is allowed with negative count.
c.WriteLen(-count)
for count != 0 {
member := members[m.randIntn(len(members))]
c.WriteBulk(member)
count++
}
return
}
// Must be unique elements.
m.shuffle(members)
if count > len(members) {
count = len(members)
}
if !withCount {
c.WriteBulk(members[0])
return
}
c.WriteLen(count)
for i := range make([]struct{}, count) {
c.WriteBulk(members[i])
}
})
}
// SREM
func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
key, fields := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
if !db.exists(key) {
c.WriteInt(0)
return
}
if db.t(key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
c.WriteInt(db.setRem(key, fields...))
})
}
// SUNION
func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) {
if len(args) < 1 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
keys := args
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setUnion(keys)
if err != nil {
c.WriteError(err.Error())
return
}
c.WriteLen(len(set))
for k := range set {
c.WriteBulk(k)
}
})
}
// SUNIONSTORE
func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
dest, keys := args[0], args[1:]
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
set, err := db.setUnion(keys)
if err != nil {
c.WriteError(err.Error())
return
}
db.del(dest, true)
db.setSet(dest, set)
c.WriteInt(len(set))
})
}
// SSCAN
func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) {
if len(args) < 2 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
var opts struct {
key string
value int
cursor int
count int
withMatch bool
match string
}
opts.key = args[0]
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
return
}
args = args[2:]
// MATCH and COUNT options
for len(args) > 0 {
if strings.ToLower(args[0]) == "count" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
count, err := strconv.Atoi(args[1])
if err != nil || count < 0 {
setDirty(c)
c.WriteError(msgInvalidInt)
return
}
if count == 0 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.count = count
args = args[2:]
continue
}
if strings.ToLower(args[0]) == "match" {
if len(args) < 2 {
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
opts.withMatch = true
opts.match = args[1]
args = args[2:]
continue
}
setDirty(c)
c.WriteError(msgSyntaxError)
return
}
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
db := m.db(ctx.selectedDB)
// return _all_ (matched) keys every time
if db.exists(opts.key) && db.t(opts.key) != "set" {
c.WriteError(ErrWrongType.Error())
return
}
members := db.setMembers(opts.key)
if opts.withMatch {
members, _ = matchKeys(members, opts.match)
}
low := opts.cursor
high := low + opts.count
// validate high is correct
if high > len(members) || high == 0 {
high = len(members)
}
if opts.cursor > high {
// invalid cursor
c.WriteLen(2)
c.WriteBulk("0") // no next cursor
c.WriteLen(0) // no elements
return
}
cursorValue := low + opts.count
if cursorValue > len(members) {
cursorValue = 0 // no next cursor
}
members = members[low:high]
c.WriteLen(2)
c.WriteBulk(fmt.Sprintf("%d", cursorValue))
c.WriteLen(len(members))
for _, k := range members {
c.WriteBulk(k)
}
})
}
func delElem(ls []string, i int) []string {
// this swap+truncate is faster but changes behaviour:
// ls[i] = ls[len(ls)-1]
// ls = ls[:len(ls)-1]
// so we do the dumb thing:
ls = append(ls[:i], ls[i+1:]...)
return ls
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+179
View File
@@ -0,0 +1,179 @@
// Commands from https://redis.io/commands#transactions
package miniredis
import (
"github.com/alicebob/miniredis/v2/server"
)
// commandsTransaction handles MULTI &c.
func commandsTransaction(m *Miniredis) {
m.srv.Register("DISCARD", m.cmdDiscard)
m.srv.Register("EXEC", m.cmdExec)
m.srv.Register("MULTI", m.cmdMulti)
m.srv.Register("UNWATCH", m.cmdUnwatch)
m.srv.Register("WATCH", m.cmdWatch)
}
// MULTI
func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
if inTx(ctx) {
c.WriteError("ERR MULTI calls can not be nested")
return
}
startTx(ctx)
c.WriteOK()
}
// EXEC
func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
if !inTx(ctx) {
c.WriteError("ERR EXEC without MULTI")
return
}
if ctx.dirtyTransaction {
c.WriteError("EXECABORT Transaction discarded because of previous errors.")
// a failed EXEC finishes the tx
stopTx(ctx)
return
}
m.Lock()
defer m.Unlock()
// Check WATCHed keys.
for t, version := range ctx.watch {
if m.db(t.db).keyVersion[t.key] > version {
// Abort! Abort!
stopTx(ctx)
c.WriteLen(-1)
return
}
}
c.WriteLen(len(ctx.transaction))
for _, cb := range ctx.transaction {
cb(c, ctx)
}
// wake up anyone who waits on anything.
m.signal.Broadcast()
stopTx(ctx)
}
// DISCARD
func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if !inTx(ctx) {
c.WriteError("ERR DISCARD without MULTI")
return
}
stopTx(ctx)
c.WriteOK()
}
// WATCH
func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
if len(args) == 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
ctx := getCtx(c)
if ctx.nested {
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
return
}
if inTx(ctx) {
c.WriteError("ERR WATCH in MULTI")
return
}
m.Lock()
defer m.Unlock()
db := m.db(ctx.selectedDB)
for _, key := range args {
watch(db, ctx, key)
}
c.WriteOK()
}
// UNWATCH
func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) {
if len(args) != 0 {
setDirty(c)
c.WriteError(errWrongNumber(cmd))
return
}
if !m.handleAuth(c) {
return
}
if m.checkPubsub(c, cmd) {
return
}
// Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me.
unwatch(getCtx(c))
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
// Do nothing if it's called in a transaction.
c.WriteOK()
})
}
+790
View File
@@ -0,0 +1,790 @@
package miniredis
import (
"errors"
"fmt"
"math"
"math/big"
"sort"
"strconv"
"time"
)
var (
errInvalidEntryID = errors.New("stream ID is invalid")
)
// exists also updates the lru
func (db *RedisDB) exists(k string) bool {
_, ok := db.keys[k]
if ok {
db.lru[k] = db.master.effectiveNow()
}
return ok
}
// t gives the type of a key, or ""
func (db *RedisDB) t(k string) string {
return db.keys[k]
}
// incr increases the version and the lru timestamp
func (db *RedisDB) incr(k string) {
db.lru[k] = db.master.effectiveNow()
db.keyVersion[k]++
}
// allKeys returns all keys. Sorted.
func (db *RedisDB) allKeys() []string {
res := make([]string, 0, len(db.keys))
for k := range db.keys {
res = append(res, k)
}
sort.Strings(res) // To make things deterministic.
return res
}
// flush removes all keys and values.
func (db *RedisDB) flush() {
db.keys = map[string]string{}
db.lru = map[string]time.Time{}
db.stringKeys = map[string]string{}
db.hashKeys = map[string]hashKey{}
db.listKeys = map[string]listKey{}
db.setKeys = map[string]setKey{}
db.hllKeys = map[string]*hll{}
db.sortedsetKeys = map[string]sortedSet{}
db.ttl = map[string]time.Duration{}
db.streamKeys = map[string]*streamKey{}
}
// move something to another db. Will return ok. Or not.
func (db *RedisDB) move(key string, to *RedisDB) bool {
if _, ok := to.keys[key]; ok {
return false
}
t, ok := db.keys[key]
if !ok {
return false
}
to.keys[key] = db.keys[key]
switch t {
case keyTypeString:
to.stringKeys[key] = db.stringKeys[key]
case keyTypeHash:
to.hashKeys[key] = db.hashKeys[key]
case keyTypeList:
to.listKeys[key] = db.listKeys[key]
case keyTypeSet:
to.setKeys[key] = db.setKeys[key]
case keyTypeSortedSet:
to.sortedsetKeys[key] = db.sortedsetKeys[key]
case keyTypeStream:
to.streamKeys[key] = db.streamKeys[key]
case keyTypeHll:
to.hllKeys[key] = db.hllKeys[key]
default:
panic("unhandled key type")
}
if v, ok := db.ttl[key]; ok {
to.ttl[key] = v
}
to.incr(key)
db.del(key, true)
return true
}
func (db *RedisDB) rename(from, to string) {
db.del(to, true)
switch db.t(from) {
case keyTypeString:
db.stringKeys[to] = db.stringKeys[from]
case keyTypeHash:
db.hashKeys[to] = db.hashKeys[from]
case keyTypeList:
db.listKeys[to] = db.listKeys[from]
case keyTypeSet:
db.setKeys[to] = db.setKeys[from]
case keyTypeSortedSet:
db.sortedsetKeys[to] = db.sortedsetKeys[from]
case keyTypeStream:
db.streamKeys[to] = db.streamKeys[from]
case keyTypeHll:
db.hllKeys[to] = db.hllKeys[from]
default:
panic("missing case")
}
db.keys[to] = db.keys[from]
if v, ok := db.ttl[from]; ok {
db.ttl[to] = v
}
db.incr(to)
db.del(from, true)
}
func (db *RedisDB) del(k string, delTTL bool) {
if !db.exists(k) {
return
}
t := db.t(k)
delete(db.keys, k)
delete(db.lru, k)
db.keyVersion[k]++
if delTTL {
delete(db.ttl, k)
}
switch t {
case keyTypeString:
delete(db.stringKeys, k)
case keyTypeHash:
delete(db.hashKeys, k)
case keyTypeList:
delete(db.listKeys, k)
case keyTypeSet:
delete(db.setKeys, k)
case keyTypeSortedSet:
delete(db.sortedsetKeys, k)
case keyTypeStream:
delete(db.streamKeys, k)
case keyTypeHll:
delete(db.hllKeys, k)
default:
panic("Unknown key type: " + t)
}
}
// stringGet returns the string key or "" on error/nonexists.
func (db *RedisDB) stringGet(k string) string {
if t, ok := db.keys[k]; !ok || t != keyTypeString {
return ""
}
return db.stringKeys[k]
}
// stringSet force set()s a key. Does not touch expire.
func (db *RedisDB) stringSet(k, v string) {
db.del(k, false)
db.keys[k] = keyTypeString
db.stringKeys[k] = v
db.incr(k)
}
// change int key value
func (db *RedisDB) stringIncr(k string, delta int) (int, error) {
v := 0
if sv, ok := db.stringKeys[k]; ok {
var err error
v, err = strconv.Atoi(sv)
if err != nil {
return 0, ErrIntValueError
}
}
if delta > 0 {
if math.MaxInt-delta < v {
return 0, ErrIntValueOverflowError
}
} else {
if math.MinInt-delta > v {
return 0, ErrIntValueOverflowError
}
}
v += delta
db.stringSet(k, strconv.Itoa(v))
return v, nil
}
// change float key value
func (db *RedisDB) stringIncrfloat(k string, delta *big.Float) (*big.Float, error) {
v := big.NewFloat(0.0)
v.SetPrec(128)
if sv, ok := db.stringKeys[k]; ok {
var err error
v, _, err = big.ParseFloat(sv, 10, 128, 0)
if err != nil {
return nil, ErrFloatValueError
}
}
v.Add(v, delta)
db.stringSet(k, formatBig(v))
return v, nil
}
// listLpush is 'left push', aka unshift. Returns the new length.
func (db *RedisDB) listLpush(k, v string) int {
l, ok := db.listKeys[k]
if !ok {
db.keys[k] = keyTypeList
}
l = append([]string{v}, l...)
db.listKeys[k] = l
db.incr(k)
return len(l)
}
// 'left pop', aka shift.
func (db *RedisDB) listLpop(k string) string {
l := db.listKeys[k]
el := l[0]
l = l[1:]
if len(l) == 0 {
db.del(k, true)
} else {
db.listKeys[k] = l
}
db.incr(k)
return el
}
func (db *RedisDB) listPush(k string, v ...string) int {
l, ok := db.listKeys[k]
if !ok {
db.keys[k] = keyTypeList
}
l = append(l, v...)
db.listKeys[k] = l
db.incr(k)
return len(l)
}
func (db *RedisDB) listPop(k string) string {
l := db.listKeys[k]
el := l[len(l)-1]
l = l[:len(l)-1]
if len(l) == 0 {
db.del(k, true)
} else {
db.listKeys[k] = l
db.incr(k)
}
return el
}
// setset replaces a whole set.
func (db *RedisDB) setSet(k string, set setKey) {
db.keys[k] = keyTypeSet
db.setKeys[k] = set
db.incr(k)
}
// setadd adds members to a set. Returns nr of new keys.
func (db *RedisDB) setAdd(k string, elems ...string) int {
s, ok := db.setKeys[k]
if !ok {
s = setKey{}
db.keys[k] = keyTypeSet
}
added := 0
for _, e := range elems {
if _, ok := s[e]; !ok {
added++
}
s[e] = struct{}{}
}
db.setKeys[k] = s
db.incr(k)
return added
}
// setrem removes members from a set. Returns nr of deleted keys.
func (db *RedisDB) setRem(k string, fields ...string) int {
s, ok := db.setKeys[k]
if !ok {
return 0
}
removed := 0
for _, f := range fields {
if _, ok := s[f]; ok {
removed++
delete(s, f)
}
}
if len(s) == 0 {
db.del(k, true)
} else {
db.setKeys[k] = s
}
db.incr(k)
return removed
}
// All members of a set.
func (db *RedisDB) setMembers(k string) []string {
set := db.setKeys[k]
members := make([]string, 0, len(set))
for k := range set {
members = append(members, k)
}
sort.Strings(members)
return members
}
// Is a SET value present?
func (db *RedisDB) setIsMember(k, v string) bool {
set, ok := db.setKeys[k]
if !ok {
return false
}
_, ok = set[v]
return ok
}
// hashFields returns all (sorted) keys ('fields') for a hash key.
func (db *RedisDB) hashFields(k string) []string {
v := db.hashKeys[k]
var r []string
for k := range v {
r = append(r, k)
}
sort.Strings(r)
return r
}
// hashValues returns all (sorted) values a hash key.
func (db *RedisDB) hashValues(k string) []string {
h := db.hashKeys[k]
var r []string
for _, v := range h {
r = append(r, v)
}
sort.Strings(r)
return r
}
// hashGet a value
func (db *RedisDB) hashGet(key, field string) string {
return db.hashKeys[key][field]
}
// hashSet returns the number of new keys
func (db *RedisDB) hashSet(k string, fv ...string) int {
if t, ok := db.keys[k]; ok && t != keyTypeHash {
db.del(k, true)
}
db.keys[k] = keyTypeHash
if _, ok := db.hashKeys[k]; !ok {
db.hashKeys[k] = map[string]string{}
}
new := 0
for idx := 0; idx < len(fv)-1; idx = idx + 2 {
f, v := fv[idx], fv[idx+1]
_, ok := db.hashKeys[k][f]
db.hashKeys[k][f] = v
db.incr(k)
if !ok {
new++
}
}
return new
}
// hashIncr changes int key value
func (db *RedisDB) hashIncr(key, field string, delta int) (int, error) {
v := 0
if h, ok := db.hashKeys[key]; ok {
if f, ok := h[field]; ok {
var err error
v, err = strconv.Atoi(f)
if err != nil {
return 0, ErrIntValueError
}
}
}
v += delta
db.hashSet(key, field, strconv.Itoa(v))
return v, nil
}
// hashIncrfloat changes float key value
func (db *RedisDB) hashIncrfloat(key, field string, delta *big.Float) (*big.Float, error) {
v := big.NewFloat(0.0)
v.SetPrec(128)
if h, ok := db.hashKeys[key]; ok {
if f, ok := h[field]; ok {
var err error
v, _, err = big.ParseFloat(f, 10, 128, 0)
if err != nil {
return nil, ErrFloatValueError
}
}
}
v.Add(v, delta)
db.hashSet(key, field, formatBig(v))
return v, nil
}
// sortedSet set returns a sortedSet as map
func (db *RedisDB) sortedSet(key string) map[string]float64 {
ss := db.sortedsetKeys[key]
return map[string]float64(ss)
}
// ssetSet sets a complete sorted set.
func (db *RedisDB) ssetSet(key string, sset sortedSet) {
db.keys[key] = keyTypeSortedSet
db.incr(key)
db.sortedsetKeys[key] = sset
}
// ssetAdd adds member to a sorted set. Returns whether this was a new member.
func (db *RedisDB) ssetAdd(key string, score float64, member string) bool {
ss, ok := db.sortedsetKeys[key]
if !ok {
ss = newSortedSet()
db.keys[key] = keyTypeSortedSet
}
_, ok = ss[member]
ss[member] = score
db.sortedsetKeys[key] = ss
db.incr(key)
return !ok
}
// All members from a sorted set, ordered by score.
func (db *RedisDB) ssetMembers(key string) []string {
ss, ok := db.sortedsetKeys[key]
if !ok {
return nil
}
elems := ss.byScore(asc)
members := make([]string, 0, len(elems))
for _, e := range elems {
members = append(members, e.member)
}
return members
}
// All members+scores from a sorted set, ordered by score.
func (db *RedisDB) ssetElements(key string) ssElems {
ss, ok := db.sortedsetKeys[key]
if !ok {
return nil
}
return ss.byScore(asc)
}
func (db *RedisDB) ssetRandomMember(key string) string {
elems := db.ssetElements(key)
if len(elems) == 0 {
return ""
}
return elems[db.master.randIntn(len(elems))].member
}
// ssetCard is the sorted set cardinality.
func (db *RedisDB) ssetCard(key string) int {
ss := db.sortedsetKeys[key]
return ss.card()
}
// ssetRank is the sorted set rank.
func (db *RedisDB) ssetRank(key, member string, d direction) (int, bool) {
ss := db.sortedsetKeys[key]
return ss.rankByScore(member, d)
}
// ssetScore is sorted set score.
func (db *RedisDB) ssetScore(key, member string) float64 {
ss := db.sortedsetKeys[key]
return ss[member]
}
// ssetMScore returns multiple scores of a list of members in a sorted set.
func (db *RedisDB) ssetMScore(key string, members []string) []float64 {
scores := make([]float64, 0, len(members))
ss := db.sortedsetKeys[key]
for _, member := range members {
scores = append(scores, ss[member])
}
return scores
}
// ssetRem is sorted set key delete.
func (db *RedisDB) ssetRem(key, member string) bool {
ss := db.sortedsetKeys[key]
_, ok := ss[member]
delete(ss, member)
if len(ss) == 0 {
// Delete key on removal of last member
db.del(key, true)
}
return ok
}
// ssetExists tells if a member exists in a sorted set.
func (db *RedisDB) ssetExists(key, member string) bool {
ss := db.sortedsetKeys[key]
_, ok := ss[member]
return ok
}
// ssetIncrby changes float sorted set score.
func (db *RedisDB) ssetIncrby(k, m string, delta float64) float64 {
ss, ok := db.sortedsetKeys[k]
if !ok {
ss = newSortedSet()
db.keys[k] = keyTypeSortedSet
db.sortedsetKeys[k] = ss
}
v, _ := ss.get(m)
v += delta
ss.set(v, m)
db.incr(k)
return v
}
// setDiff implements the logic behind SDIFF*
func (db *RedisDB) setDiff(keys []string) (setKey, error) {
key := keys[0]
keys = keys[1:]
if db.exists(key) && db.t(key) != keyTypeSet {
return nil, ErrWrongType
}
s := setKey{}
for k := range db.setKeys[key] {
s[k] = struct{}{}
}
for _, sk := range keys {
if !db.exists(sk) {
continue
}
if db.t(sk) != keyTypeSet {
return nil, ErrWrongType
}
for e := range db.setKeys[sk] {
delete(s, e)
}
}
return s, nil
}
// setInter implements the logic behind SINTER*
// len keys needs to be > 0
func (db *RedisDB) setInter(keys []string) (setKey, error) {
// all keys must either not exist, or be of type "set".
for _, key := range keys {
if db.exists(key) && db.t(key) != keyTypeSet {
return nil, ErrWrongType
}
}
key := keys[0]
keys = keys[1:]
if !db.exists(key) {
return nil, nil
}
if db.t(key) != keyTypeSet {
return nil, ErrWrongType
}
s := setKey{}
for k := range db.setKeys[key] {
s[k] = struct{}{}
}
for _, sk := range keys {
if !db.exists(sk) {
return setKey{}, nil
}
if db.t(sk) != keyTypeSet {
return nil, ErrWrongType
}
other := db.setKeys[sk]
for e := range s {
if _, ok := other[e]; ok {
continue
}
delete(s, e)
}
}
return s, nil
}
// setIntercard implements the logic behind SINTER*
// len keys needs to be > 0
func (db *RedisDB) setIntercard(keys []string, limit int) (int, error) {
// all keys must either not exist, or be of type "set".
allExist := true
for _, key := range keys {
exists := db.exists(key)
allExist = allExist && exists
if exists && db.t(key) != "set" {
return 0, ErrWrongType
}
}
if !allExist {
return 0, nil
}
smallestKey := keys[0]
smallestIdx := 0
for i, key := range keys {
if len(db.setKeys[key]) < len(db.setKeys[smallestKey]) {
smallestKey = key
smallestIdx = i
}
}
keys[smallestIdx] = keys[len(keys)-1]
keys = keys[:len(keys)-1]
count := 0
for item := range db.setKeys[smallestKey] {
inIntersection := true
for _, key := range keys {
if _, ok := db.setKeys[key][item]; !ok {
inIntersection = false
break
}
}
if inIntersection {
count++
if count == limit {
break
}
}
}
return count, nil
}
// setUnion implements the logic behind SUNION*
func (db *RedisDB) setUnion(keys []string) (setKey, error) {
key := keys[0]
keys = keys[1:]
if db.exists(key) && db.t(key) != "set" {
return nil, ErrWrongType
}
s := setKey{}
for k := range db.setKeys[key] {
s[k] = struct{}{}
}
for _, sk := range keys {
if !db.exists(sk) {
continue
}
if db.t(sk) != "set" {
return nil, ErrWrongType
}
for e := range db.setKeys[sk] {
s[e] = struct{}{}
}
}
return s, nil
}
func (db *RedisDB) newStream(key string) (*streamKey, error) {
if s, err := db.stream(key); err != nil {
return nil, err
} else if s != nil {
return nil, fmt.Errorf("ErrAlreadyExists")
}
db.keys[key] = keyTypeStream
s := newStreamKey()
db.streamKeys[key] = s
db.incr(key)
return s, nil
}
// return existing stream, or nil.
func (db *RedisDB) stream(key string) (*streamKey, error) {
if db.exists(key) && db.t(key) != keyTypeStream {
return nil, ErrWrongType
}
return db.streamKeys[key], nil
}
// return existing stream group, or nil.
func (db *RedisDB) streamGroup(key, group string) (*streamGroup, error) {
s, err := db.stream(key)
if err != nil || s == nil {
return nil, err
}
return s.groups[group], nil
}
// fastForward proceeds the current timestamp with duration, works as a time machine
func (db *RedisDB) fastForward(duration time.Duration) {
for _, key := range db.allKeys() {
if value, ok := db.ttl[key]; ok {
db.ttl[key] = value - duration
db.checkTTL(key)
}
}
}
func (db *RedisDB) checkTTL(key string) {
if v, ok := db.ttl[key]; ok && v <= 0 {
db.del(key, true)
}
}
// hllAdd adds members to a hll. Returns 1 if at least 1 if internal HyperLogLog was altered, otherwise 0
func (db *RedisDB) hllAdd(k string, elems ...string) int {
s, ok := db.hllKeys[k]
if !ok {
s = newHll()
db.keys[k] = keyTypeHll
}
hllAltered := 0
for _, e := range elems {
if s.Add([]byte(e)) {
hllAltered = 1
}
}
db.hllKeys[k] = s
db.incr(k)
return hllAltered
}
// hllCount estimates the amount of members added to hll by hllAdd. If called with several arguments, hllCount returns a sum of estimations
func (db *RedisDB) hllCount(keys []string) (int, error) {
countOverall := 0
for _, key := range keys {
if db.exists(key) && db.t(key) != keyTypeHll {
return 0, ErrNotValidHllValue
}
if !db.exists(key) {
continue
}
countOverall += db.hllKeys[key].Count()
}
return countOverall, nil
}
// hllMerge merges all the hlls provided as keys to the first key. Creates a new hll in the first key if it contains nothing
func (db *RedisDB) hllMerge(keys []string) error {
for _, key := range keys {
if db.exists(key) && db.t(key) != keyTypeHll {
return ErrNotValidHllValue
}
}
destKey := keys[0]
restKeys := keys[1:]
var destHll *hll
if db.exists(destKey) {
destHll = db.hllKeys[destKey]
} else {
destHll = newHll()
}
for _, key := range restKeys {
if !db.exists(key) {
continue
}
destHll.Merge(db.hllKeys[key])
}
db.hllKeys[destKey] = destHll
db.keys[destKey] = keyTypeHll
db.incr(destKey)
return nil
}
+824
View File
@@ -0,0 +1,824 @@
package miniredis
// Commands to modify and query our databases directly.
import (
"errors"
"math/big"
"time"
)
var (
// ErrKeyNotFound is returned when a key doesn't exist.
ErrKeyNotFound = errors.New(msgKeyNotFound)
// ErrWrongType when a key is not the right type.
ErrWrongType = errors.New(msgWrongType)
// ErrNotValidHllValue when a key is not a valid HyperLogLog string value.
ErrNotValidHllValue = errors.New(msgNotValidHllValue)
// ErrIntValueError can returned by INCRBY
ErrIntValueError = errors.New(msgInvalidInt)
// ErrIntValueOverflowError can be returned by INCR, DECR, INCRBY, DECRBY
ErrIntValueOverflowError = errors.New(msgIntOverflow)
// ErrFloatValueError can returned by INCRBYFLOAT
ErrFloatValueError = errors.New(msgInvalidFloat)
)
// Select sets the DB id for all direct commands.
func (m *Miniredis) Select(i int) {
m.Lock()
defer m.Unlock()
m.selectedDB = i
}
// Keys returns all keys from the selected database, sorted.
func (m *Miniredis) Keys() []string {
return m.DB(m.selectedDB).Keys()
}
// Keys returns all keys, sorted.
func (db *RedisDB) Keys() []string {
db.master.Lock()
defer db.master.Unlock()
return db.allKeys()
}
// FlushAll removes all keys from all databases.
func (m *Miniredis) FlushAll() {
m.Lock()
defer m.Unlock()
defer m.signal.Broadcast()
m.flushAll()
}
func (m *Miniredis) flushAll() {
for _, db := range m.dbs {
db.flush()
}
}
// FlushDB removes all keys from the selected database.
func (m *Miniredis) FlushDB() {
m.DB(m.selectedDB).FlushDB()
}
// FlushDB removes all keys.
func (db *RedisDB) FlushDB() {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
db.flush()
}
// Get returns string keys added with SET.
func (m *Miniredis) Get(k string) (string, error) {
return m.DB(m.selectedDB).Get(k)
}
// Get returns a string key.
func (db *RedisDB) Get(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != keyTypeString {
return "", ErrWrongType
}
return db.stringGet(k), nil
}
// Set sets a string key. Removes expire.
func (m *Miniredis) Set(k, v string) error {
return m.DB(m.selectedDB).Set(k, v)
}
// Set sets a string key. Removes expire.
// Unlike redis the key can't be an existing non-string key.
func (db *RedisDB) Set(k, v string) error {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeString {
return ErrWrongType
}
db.del(k, true) // Remove expire
db.stringSet(k, v)
return nil
}
// Incr changes a int string value by delta.
func (m *Miniredis) Incr(k string, delta int) (int, error) {
return m.DB(m.selectedDB).Incr(k, delta)
}
// Incr changes a int string value by delta.
func (db *RedisDB) Incr(k string, delta int) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeString {
return 0, ErrWrongType
}
return db.stringIncr(k, delta)
}
// IncrByFloat increments the float value of a key by the given delta.
// is an alias for Miniredis.Incrfloat
func (m *Miniredis) IncrByFloat(k string, delta float64) (float64, error) {
return m.Incrfloat(k, delta)
}
// Incrfloat changes a float string value by delta.
func (m *Miniredis) Incrfloat(k string, delta float64) (float64, error) {
return m.DB(m.selectedDB).Incrfloat(k, delta)
}
// Incrfloat changes a float string value by delta.
func (db *RedisDB) Incrfloat(k string, delta float64) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeString {
return 0, ErrWrongType
}
v, err := db.stringIncrfloat(k, big.NewFloat(delta))
if err != nil {
return 0, err
}
vf, _ := v.Float64()
return vf, nil
}
// List returns the list k, or an error if it's not there or something else.
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
// range-ing.
func (m *Miniredis) List(k string) ([]string, error) {
return m.DB(m.selectedDB).List(k)
}
// List returns the list k, or an error if it's not there or something else.
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
// range-ing.
func (db *RedisDB) List(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != keyTypeList {
return nil, ErrWrongType
}
return db.listKeys[k], nil
}
// Lpush prepends one value to a list. Returns the new length.
func (m *Miniredis) Lpush(k, v string) (int, error) {
return m.DB(m.selectedDB).Lpush(k, v)
}
// Lpush prepends one value to a list. Returns the new length.
func (db *RedisDB) Lpush(k, v string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeList {
return 0, ErrWrongType
}
return db.listLpush(k, v), nil
}
// Lpop removes and returns the last element in a list.
func (m *Miniredis) Lpop(k string) (string, error) {
return m.DB(m.selectedDB).Lpop(k)
}
// Lpop removes and returns the last element in a list.
func (db *RedisDB) Lpop(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != keyTypeList {
return "", ErrWrongType
}
return db.listLpop(k), nil
}
// RPush appends one or multiple values to a list. Returns the new length.
// An alias for Push
func (m *Miniredis) RPush(k string, v ...string) (int, error) {
return m.Push(k, v...)
}
// Push add element at the end. Returns the new length.
func (m *Miniredis) Push(k string, v ...string) (int, error) {
return m.DB(m.selectedDB).Push(k, v...)
}
// Push add element at the end. Is called RPUSH in redis. Returns the new length.
func (db *RedisDB) Push(k string, v ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeList {
return 0, ErrWrongType
}
return db.listPush(k, v...), nil
}
// RPop is an alias for Pop
func (m *Miniredis) RPop(k string) (string, error) {
return m.Pop(k)
}
// Pop removes and returns the last element. Is called RPOP in Redis.
func (m *Miniredis) Pop(k string) (string, error) {
return m.DB(m.selectedDB).Pop(k)
}
// Pop removes and returns the last element. Is called RPOP in Redis.
func (db *RedisDB) Pop(k string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if !db.exists(k) {
return "", ErrKeyNotFound
}
if db.t(k) != keyTypeList {
return "", ErrWrongType
}
return db.listPop(k), nil
}
// SAdd adds keys to a set. Returns the number of new keys.
// Alias for SetAdd
func (m *Miniredis) SAdd(k string, elems ...string) (int, error) {
return m.SetAdd(k, elems...)
}
// SetAdd adds keys to a set. Returns the number of new keys.
func (m *Miniredis) SetAdd(k string, elems ...string) (int, error) {
return m.DB(m.selectedDB).SetAdd(k, elems...)
}
// SetAdd adds keys to a set. Returns the number of new keys.
func (db *RedisDB) SetAdd(k string, elems ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeSet {
return 0, ErrWrongType
}
return db.setAdd(k, elems...), nil
}
// SMembers returns all keys in a set, sorted.
// Alias for Members.
func (m *Miniredis) SMembers(k string) ([]string, error) {
return m.Members(k)
}
// Members returns all keys in a set, sorted.
func (m *Miniredis) Members(k string) ([]string, error) {
return m.DB(m.selectedDB).Members(k)
}
// Members gives all set keys. Sorted.
func (db *RedisDB) Members(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != keyTypeSet {
return nil, ErrWrongType
}
return db.setMembers(k), nil
}
// SIsMember tells if value is in the set.
// Alias for IsMember
func (m *Miniredis) SIsMember(k, v string) (bool, error) {
return m.IsMember(k, v)
}
// IsMember tells if value is in the set.
func (m *Miniredis) IsMember(k, v string) (bool, error) {
return m.DB(m.selectedDB).IsMember(k, v)
}
// IsMember tells if value is in the set.
func (db *RedisDB) IsMember(k, v string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return false, ErrKeyNotFound
}
if db.t(k) != keyTypeSet {
return false, ErrWrongType
}
return db.setIsMember(k, v), nil
}
// HKeys returns all (sorted) keys ('fields') for a hash key.
func (m *Miniredis) HKeys(k string) ([]string, error) {
return m.DB(m.selectedDB).HKeys(k)
}
// HKeys returns all (sorted) keys ('fields') for a hash key.
func (db *RedisDB) HKeys(key string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(key) {
return nil, ErrKeyNotFound
}
if db.t(key) != keyTypeHash {
return nil, ErrWrongType
}
return db.hashFields(key), nil
}
// Del deletes a key and any expiration value. Returns whether there was a key.
func (m *Miniredis) Del(k string) bool {
return m.DB(m.selectedDB).Del(k)
}
// Del deletes a key and any expiration value. Returns whether there was a key.
func (db *RedisDB) Del(k string) bool {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if !db.exists(k) {
return false
}
db.del(k, true)
return true
}
// Unlink deletes a key and any expiration value. Returns where there was a key.
// It's exactly the same as Del() and is not async. It is here for the consistency.
func (m *Miniredis) Unlink(k string) bool {
return m.Del(k)
}
// Unlink deletes a key and any expiration value. Returns where there was a key.
// It's exactly the same as Del() and is not async. It is here for the consistency.
func (db *RedisDB) Unlink(k string) bool {
return db.Del(k)
}
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
// PEXPIREAT.
// Note: this direct function returns 0 if there is no TTL set, unlike redis,
// which returns -1.
func (m *Miniredis) TTL(k string) time.Duration {
return m.DB(m.selectedDB).TTL(k)
}
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
// PEXPIREAT.
// 0 if not set.
func (db *RedisDB) TTL(k string) time.Duration {
db.master.Lock()
defer db.master.Unlock()
return db.ttl[k]
}
// SetTTL sets the TTL of a key.
func (m *Miniredis) SetTTL(k string, ttl time.Duration) {
m.DB(m.selectedDB).SetTTL(k, ttl)
}
// SetTTL sets the time to live of a key.
func (db *RedisDB) SetTTL(k string, ttl time.Duration) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
db.ttl[k] = ttl
db.incr(k)
}
// Type gives the type of a key, or ""
func (m *Miniredis) Type(k string) string {
return m.DB(m.selectedDB).Type(k)
}
// Type gives the type of a key, or ""
func (db *RedisDB) Type(k string) string {
db.master.Lock()
defer db.master.Unlock()
return db.t(k)
}
// Exists tells whether a key exists.
func (m *Miniredis) Exists(k string) bool {
return m.DB(m.selectedDB).Exists(k)
}
// Exists tells whether a key exists.
func (db *RedisDB) Exists(k string) bool {
db.master.Lock()
defer db.master.Unlock()
return db.exists(k)
}
// HGet returns hash keys added with HSET.
// This will return an empty string if the key is not set. Redis would return
// a nil.
// Returns empty string when the key is of a different type.
func (m *Miniredis) HGet(k, f string) string {
return m.DB(m.selectedDB).HGet(k, f)
}
// HGet returns hash keys added with HSET.
// Returns empty string when the key is of a different type.
func (db *RedisDB) HGet(k, f string) string {
db.master.Lock()
defer db.master.Unlock()
h, ok := db.hashKeys[k]
if !ok {
return ""
}
return h[f]
}
// HSet sets hash keys.
// If there is another key by the same name it will be gone.
func (m *Miniredis) HSet(k string, fv ...string) {
m.DB(m.selectedDB).HSet(k, fv...)
}
// HSet sets hash keys.
// If there is another key by the same name it will be gone.
func (db *RedisDB) HSet(k string, fv ...string) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
db.hashSet(k, fv...)
}
// HDel deletes a hash key.
func (m *Miniredis) HDel(k, f string) {
m.DB(m.selectedDB).HDel(k, f)
}
// HDel deletes a hash key.
func (db *RedisDB) HDel(k, f string) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
db.hdel(k, f)
}
func (db *RedisDB) hdel(k, f string) {
if _, ok := db.hashKeys[k]; !ok {
return
}
delete(db.hashKeys[k], f)
db.incr(k)
}
// HIncrBy increases the integer value of a hash field by delta (int).
func (m *Miniredis) HIncrBy(k, f string, delta int) (int, error) {
return m.HIncr(k, f, delta)
}
// HIncr increases a key/field by delta (int).
func (m *Miniredis) HIncr(k, f string, delta int) (int, error) {
return m.DB(m.selectedDB).HIncr(k, f, delta)
}
// HIncr increases a key/field by delta (int).
func (db *RedisDB) HIncr(k, f string, delta int) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
return db.hashIncr(k, f, delta)
}
// HIncrByFloat increases a key/field by delta (float).
func (m *Miniredis) HIncrByFloat(k, f string, delta float64) (float64, error) {
return m.HIncrfloat(k, f, delta)
}
// HIncrfloat increases a key/field by delta (float).
func (m *Miniredis) HIncrfloat(k, f string, delta float64) (float64, error) {
return m.DB(m.selectedDB).HIncrfloat(k, f, delta)
}
// HIncrfloat increases a key/field by delta (float).
func (db *RedisDB) HIncrfloat(k, f string, delta float64) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
v, err := db.hashIncrfloat(k, f, big.NewFloat(delta))
if err != nil {
return 0, err
}
vf, _ := v.Float64()
return vf, nil
}
// SRem removes fields from a set. Returns number of deleted fields.
func (m *Miniredis) SRem(k string, fields ...string) (int, error) {
return m.DB(m.selectedDB).SRem(k, fields...)
}
// SRem removes fields from a set. Returns number of deleted fields.
func (db *RedisDB) SRem(k string, fields ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if !db.exists(k) {
return 0, ErrKeyNotFound
}
if db.t(k) != keyTypeSet {
return 0, ErrWrongType
}
return db.setRem(k, fields...), nil
}
// ZAdd adds a score,member to a sorted set.
func (m *Miniredis) ZAdd(k string, score float64, member string) (bool, error) {
return m.DB(m.selectedDB).ZAdd(k, score, member)
}
// ZAdd adds a score,member to a sorted set.
func (db *RedisDB) ZAdd(k string, score float64, member string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if db.exists(k) && db.t(k) != keyTypeSortedSet {
return false, ErrWrongType
}
return db.ssetAdd(k, score, member), nil
}
// ZMembers returns all members of a sorted set by score
func (m *Miniredis) ZMembers(k string) ([]string, error) {
return m.DB(m.selectedDB).ZMembers(k)
}
// ZMembers returns all members of a sorted set by score
func (db *RedisDB) ZMembers(k string) ([]string, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != keyTypeSortedSet {
return nil, ErrWrongType
}
return db.ssetMembers(k), nil
}
// SortedSet returns a raw string->float64 map.
func (m *Miniredis) SortedSet(k string) (map[string]float64, error) {
return m.DB(m.selectedDB).SortedSet(k)
}
// SortedSet returns a raw string->float64 map.
func (db *RedisDB) SortedSet(k string) (map[string]float64, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != keyTypeSortedSet {
return nil, ErrWrongType
}
return db.sortedSet(k), nil
}
// ZRem deletes a member. Returns whether the was a key.
func (m *Miniredis) ZRem(k, member string) (bool, error) {
return m.DB(m.selectedDB).ZRem(k, member)
}
// ZRem deletes a member. Returns whether the was a key.
func (db *RedisDB) ZRem(k, member string) (bool, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
if !db.exists(k) {
return false, ErrKeyNotFound
}
if db.t(k) != keyTypeSortedSet {
return false, ErrWrongType
}
return db.ssetRem(k, member), nil
}
// ZScore gives the score of a sorted set member.
func (m *Miniredis) ZScore(k, member string) (float64, error) {
return m.DB(m.selectedDB).ZScore(k, member)
}
// ZScore gives the score of a sorted set member.
func (db *RedisDB) ZScore(k, member string) (float64, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return 0, ErrKeyNotFound
}
if db.t(k) != keyTypeSortedSet {
return 0, ErrWrongType
}
return db.ssetScore(k, member), nil
}
// ZScore gives scores of a list of members in a sorted set.
func (m *Miniredis) ZMScore(k string, members ...string) ([]float64, error) {
return m.DB(m.selectedDB).ZMScore(k, members)
}
func (db *RedisDB) ZMScore(k string, members []string) ([]float64, error) {
db.master.Lock()
defer db.master.Unlock()
if !db.exists(k) {
return nil, ErrKeyNotFound
}
if db.t(k) != keyTypeSortedSet {
return nil, ErrWrongType
}
return db.ssetMScore(k, members), nil
}
// XAdd adds an entry to a stream. `id` can be left empty or be '*'.
// If a value is given normal XADD rules apply. Values should be an even
// length.
func (m *Miniredis) XAdd(k string, id string, values []string) (string, error) {
return m.DB(m.selectedDB).XAdd(k, id, values)
}
// XAdd adds an entry to a stream. `id` can be left empty or be '*'.
// If a value is given normal XADD rules apply. Values should be an even
// length.
func (db *RedisDB) XAdd(k string, id string, values []string) (string, error) {
db.master.Lock()
defer db.master.Unlock()
defer db.master.signal.Broadcast()
s, err := db.stream(k)
if err != nil {
return "", err
}
if s == nil {
s, _ = db.newStream(k)
}
return s.add(id, values, db.master.effectiveNow())
}
// Stream returns a slice of stream entries. Oldest first.
func (m *Miniredis) Stream(k string) ([]StreamEntry, error) {
return m.DB(m.selectedDB).Stream(k)
}
// Stream returns a slice of stream entries. Oldest first.
func (db *RedisDB) Stream(key string) ([]StreamEntry, error) {
db.master.Lock()
defer db.master.Unlock()
s, err := db.stream(key)
if err != nil {
return nil, err
}
if s == nil {
return nil, nil
}
return s.entries, nil
}
// Publish a message to subscribers. Returns the number of receivers.
func (m *Miniredis) Publish(channel, message string) int {
m.Lock()
defer m.Unlock()
return m.publish(channel, message)
}
// PubSubChannels is "PUBSUB CHANNELS <pattern>". An empty pattern is fine
// (meaning all channels).
// Returned channels will be ordered alphabetically.
func (m *Miniredis) PubSubChannels(pattern string) []string {
m.Lock()
defer m.Unlock()
return activeChannels(m.allSubscribers(), pattern)
}
// PubSubNumSub is "PUBSUB NUMSUB [channels]". It returns all channels with their
// subscriber count.
func (m *Miniredis) PubSubNumSub(channels ...string) map[string]int {
m.Lock()
defer m.Unlock()
subs := m.allSubscribers()
res := map[string]int{}
for _, channel := range channels {
res[channel] = countSubs(subs, channel)
}
return res
}
// PubSubNumPat is "PUBSUB NUMPAT"
func (m *Miniredis) PubSubNumPat() int {
m.Lock()
defer m.Unlock()
return countPsubs(m.allSubscribers())
}
// PfAdd adds keys to a hll. Returns the flag which equals to 1 if the inner hll value has been changed.
func (m *Miniredis) PfAdd(k string, elems ...string) (int, error) {
return m.DB(m.selectedDB).HllAdd(k, elems...)
}
// HllAdd adds keys to a hll. Returns the flag which equals to true if the inner hll value has been changed.
func (db *RedisDB) HllAdd(k string, elems ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
if db.exists(k) && db.t(k) != keyTypeHll {
return 0, ErrWrongType
}
return db.hllAdd(k, elems...), nil
}
// PfCount returns an estimation of the amount of elements previously added to a hll.
func (m *Miniredis) PfCount(keys ...string) (int, error) {
return m.DB(m.selectedDB).HllCount(keys...)
}
// HllCount returns an estimation of the amount of elements previously added to a hll.
func (db *RedisDB) HllCount(keys ...string) (int, error) {
db.master.Lock()
defer db.master.Unlock()
return db.hllCount(keys)
}
// PfMerge merges all the input hlls into a hll under destKey key.
func (m *Miniredis) PfMerge(destKey string, sourceKeys ...string) error {
return m.DB(m.selectedDB).HllMerge(destKey, sourceKeys...)
}
// HllMerge merges all the input hlls into a hll under destKey key.
func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error {
db.master.Lock()
defer db.master.Unlock()
return db.hllMerge(append([]string{destKey}, sourceKeys...))
}
// Copy a value.
// Needs the IDs of both the source and dest DBs (which can differ).
// Returns ErrKeyNotFound if src does not exist.
// Overwrites dest if it already exists (unlike the redis command, which needs a flag to allow that).
func (m *Miniredis) Copy(srcDB int, src string, destDB int, dest string) error {
return m.copy(m.DB(srcDB), src, m.DB(destDB), dest)
}
+26
View File
@@ -0,0 +1,26 @@
This code is derived from the C code in redis-7.2.0/deps/fpconv/*, which has
this license:
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
+6
View File
@@ -0,0 +1,6 @@
.PHONY: test fuzz
test:
go test
fuzz:
go test -fuzz=Fuzz
+3
View File
@@ -0,0 +1,3 @@
This is a translation of the actual C code in Redis (7.2) which does the float
-> string conversion.
Strconv does a close enough job, but we can use the exact same logic, so why not.
+286
View File
@@ -0,0 +1,286 @@
package fpconv
import (
"math"
)
var (
fracmask uint64 = 0x000FFFFFFFFFFFFF
expmask uint64 = 0x7FF0000000000000
hiddenbit uint64 = 0x0010000000000000
signmask uint64 = 0x8000000000000000
expbias int64 = 1023 + 52
zeros = []rune("0000000000000000000000")
tens = []uint64{
10000000000000000000,
1000000000000000000,
100000000000000000,
10000000000000000,
1000000000000000,
100000000000000,
10000000000000,
1000000000000,
100000000000,
10000000000,
1000000000,
100000000,
10000000,
1000000,
100000,
10000,
1000,
100,
10,
1}
)
func absv(n int) int {
if n < 0 {
return -n
}
return n
}
func minv(a, b int) int {
if a < b {
return a
}
return b
}
func Dtoa(d float64) string {
var (
dest [25]rune // Note C has 24, which is broken
digits [18]rune
str_len int = 0
neg = false
)
if get_dbits(d)&signmask != 0 {
dest[0] = '-'
str_len++
neg = true
}
if spec := filter_special(d, dest[str_len:]); spec != 0 {
return string(dest[:str_len+spec])
}
var (
k int = 0
ndigits int = grisu2(d, &digits, &k)
)
str_len += emit_digits(&digits, ndigits, dest[str_len:], k, neg)
return string(dest[:str_len])
}
func filter_special(fp float64, dest []rune) int {
if fp == 0.0 {
dest[0] = '0'
return 1
}
if math.IsNaN(fp) {
dest[0] = 'n'
dest[1] = 'a'
dest[2] = 'n'
return 3
}
if math.IsInf(fp, 0) {
dest[0] = 'i'
dest[1] = 'n'
dest[2] = 'f'
return 3
}
return 0
}
func grisu2(d float64, digits *[18]rune, K *int) int {
w := build_fp(d)
lower, upper := get_normalized_boundaries(w)
w = normalize(w)
var k int64
cp := find_cachedpow10(upper.exp, &k)
w = multiply(w, cp)
upper = multiply(upper, cp)
lower = multiply(lower, cp)
lower.frac++
upper.frac--
*K = int(-k)
return generate_digits(w, upper, lower, digits[:], K)
}
func emit_digits(digits *[18]rune, ndigits int, dest []rune, K int, neg bool) int {
exp := int(absv(K + ndigits - 1))
/* write plain integer */
if K >= 0 && (exp < (ndigits + 7)) {
copy(dest, digits[:ndigits])
copy(dest[ndigits:], zeros[:K])
return ndigits + K
}
/* write decimal w/o scientific notation */
if K < 0 && (K > -7 || exp < 4) {
offset := int(ndigits - absv(K))
/* fp < 1.0 -> write leading zero */
if offset <= 0 {
offset = -offset
dest[0] = '0'
dest[1] = '.'
copy(dest[2:], zeros[:offset])
copy(dest[offset+2:], digits[:ndigits])
return ndigits + 2 + offset
/* fp > 1.0 */
} else {
copy(dest, digits[:offset])
dest[offset] = '.'
copy(dest[offset+1:], digits[offset:offset+ndigits-offset])
return ndigits + 1
}
}
/* write decimal w/ scientific notation */
l := 18 // was: 18-neg
if neg {
l--
}
ndigits = minv(ndigits, l)
var idx int = 0
dest[idx] = digits[0]
idx++
if ndigits > 1 {
dest[idx] = '.'
idx++
copy(dest[idx:], digits[+1:ndigits-1+1])
idx += ndigits - 1
}
dest[idx] = 'e'
idx++
sign := '+'
if K+ndigits-1 < 0 {
sign = '-'
}
dest[idx] = sign
idx++
var cent rune = 0
if exp > 99 {
cent = rune(exp / 100)
dest[idx] = cent + '0'
idx++
exp -= int(cent) * 100
}
if exp > 9 {
dec := rune(exp / 10)
dest[idx] = dec + '0'
idx++
exp -= int(dec) * 10
} else if cent != 0 {
dest[idx] = '0'
idx++
}
dest[idx] = rune(exp%10) + '0'
idx++
return idx
}
func generate_digits(fp, upper, lower Fp, digits []rune, K *int) int {
var (
wfrac = uint64(upper.frac - fp.frac)
delta = uint64(upper.frac - lower.frac)
)
one := Fp{
frac: 1 << -upper.exp,
exp: upper.exp,
}
part1 := uint64(upper.frac >> -one.exp)
part2 := uint64(upper.frac & (one.frac - 1))
var (
idx = 0
kappa = 10
index = 10
)
/* 1000000000 */
for ; kappa > 0; index++ {
div := tens[index]
digit := part1 / div
if digit != 0 || idx != 0 {
digits[idx] = rune(digit) + '0'
idx++
}
part1 -= digit * div
kappa--
tmp := (part1 << -one.exp) + part2
if tmp <= delta {
*K += kappa
round_digit(digits, idx, delta, tmp, div<<-one.exp, wfrac)
return idx
}
}
/* 10 */
index = 18
for {
var unit uint64 = tens[index]
part2 *= 10
delta *= 10
kappa--
digit := part2 >> -one.exp
if digit != 0 || idx != 0 {
digits[idx] = rune(digit) + '0'
idx++
}
part2 &= uint64(one.frac) - 1
if part2 < delta {
*K += kappa
round_digit(digits, idx, delta, part2, uint64(one.frac), wfrac*unit)
return idx
}
index--
}
}
func round_digit(digits []rune,
ndigits int,
delta uint64,
rem uint64,
kappa uint64,
frac uint64) {
for rem < frac && delta-rem >= kappa &&
(rem+kappa < frac || frac-rem > rem+kappa-frac) {
digits[ndigits-1]--
rem += kappa
}
}
+96
View File
@@ -0,0 +1,96 @@
package fpconv
import (
"math"
)
type (
Fp struct {
frac uint64
exp int64
}
)
func build_fp(d float64) Fp {
bits := get_dbits(d)
fp := Fp{
frac: bits & fracmask,
exp: int64((bits & expmask) >> 52),
}
if fp.exp != 0 {
fp.frac += hiddenbit
fp.exp -= expbias
} else {
fp.exp = -expbias + 1
}
return fp
}
func normalize(fp Fp) Fp {
for (fp.frac & hiddenbit) == 0 {
fp.frac <<= 1
fp.exp--
}
var shift int64 = 64 - 52 - 1
fp.frac <<= shift
fp.exp -= shift
return fp
}
func multiply(a, b Fp) Fp {
lomask := uint64(0x00000000FFFFFFFF)
var (
ah_bl = uint64((a.frac >> 32) * (b.frac & lomask))
al_bh = uint64((a.frac & lomask) * (b.frac >> 32))
al_bl = uint64((a.frac & lomask) * (b.frac & lomask))
ah_bh = uint64((a.frac >> 32) * (b.frac >> 32))
)
tmp := uint64((ah_bl & lomask) + (al_bh & lomask) + (al_bl >> 32))
/* round up */
tmp += uint64(1) << 31
return Fp{
ah_bh + (ah_bl >> 32) + (al_bh >> 32) + (tmp >> 32),
a.exp + b.exp + 64,
}
}
func get_dbits(d float64) uint64 {
return math.Float64bits(d)
}
func get_normalized_boundaries(fp Fp) (Fp, Fp) {
upper := Fp{
frac: (fp.frac << 1) + 1,
exp: fp.exp - 1,
}
for (upper.frac & (hiddenbit << 1)) == 0 {
upper.frac <<= 1
upper.exp--
}
var u_shift int64 = 64 - 52 - 2
upper.frac <<= u_shift
upper.exp = upper.exp - u_shift
l_shift := int64(1)
if fp.frac == hiddenbit {
l_shift = 2
}
lower := Fp{
frac: (fp.frac << l_shift) - 1,
exp: fp.exp - l_shift,
}
lower.frac <<= lower.exp - upper.exp
lower.exp = upper.exp
return lower, upper
}
+82
View File
@@ -0,0 +1,82 @@
package fpconv
var (
npowers int64 = 87
steppowers int64 = 8
firstpower int64 = -348 /* 10 ^ -348 */
expmax = -32
expmin = -60
powers_ten = []Fp{
{18054884314459144840, -1220}, {13451937075301367670, -1193},
{10022474136428063862, -1166}, {14934650266808366570, -1140},
{11127181549972568877, -1113}, {16580792590934885855, -1087},
{12353653155963782858, -1060}, {18408377700990114895, -1034},
{13715310171984221708, -1007}, {10218702384817765436, -980},
{15227053142812498563, -954}, {11345038669416679861, -927},
{16905424996341287883, -901}, {12595523146049147757, -874},
{9384396036005875287, -847}, {13983839803942852151, -821},
{10418772551374772303, -794}, {15525180923007089351, -768},
{11567161174868858868, -741}, {17236413322193710309, -715},
{12842128665889583758, -688}, {9568131466127621947, -661},
{14257626930069360058, -635}, {10622759856335341974, -608},
{15829145694278690180, -582}, {11793632577567316726, -555},
{17573882009934360870, -529}, {13093562431584567480, -502},
{9755464219737475723, -475}, {14536774485912137811, -449},
{10830740992659433045, -422}, {16139061738043178685, -396},
{12024538023802026127, -369}, {17917957937422433684, -343},
{13349918974505688015, -316}, {9946464728195732843, -289},
{14821387422376473014, -263}, {11042794154864902060, -236},
{16455045573212060422, -210}, {12259964326927110867, -183},
{18268770466636286478, -157}, {13611294676837538539, -130},
{10141204801825835212, -103}, {15111572745182864684, -77},
{11258999068426240000, -50}, {16777216000000000000, -24},
{12500000000000000000, 3}, {9313225746154785156, 30},
{13877787807814456755, 56}, {10339757656912845936, 83},
{15407439555097886824, 109}, {11479437019748901445, 136},
{17105694144590052135, 162}, {12744735289059618216, 189},
{9495567745759798747, 216}, {14149498560666738074, 242},
{10542197943230523224, 269}, {15709099088952724970, 295},
{11704190886730495818, 322}, {17440603504673385349, 348},
{12994262207056124023, 375}, {9681479787123295682, 402},
{14426529090290212157, 428}, {10748601772107342003, 455},
{16016664761464807395, 481}, {11933345169920330789, 508},
{17782069995880619868, 534}, {13248674568444952270, 561},
{9871031767461413346, 588}, {14708983551653345445, 614},
{10959046745042015199, 641}, {16330252207878254650, 667},
{12166986024289022870, 694}, {18130221999122236476, 720},
{13508068024458167312, 747}, {10064294952495520794, 774},
{14996968138956309548, 800}, {11173611982879273257, 827},
{16649979327439178909, 853}, {12405201291620119593, 880},
{9242595204427927429, 907}, {13772540099066387757, 933},
{10261342003245940623, 960}, {15290591125556738113, 986},
{11392378155556871081, 1013}, {16975966327722178521, 1039},
{12648080533535911531, 1066},
}
)
func find_cachedpow10(exp int64, k *int64) Fp {
one_log_ten := 0.30102999566398114
approx := int64(float64(-(exp + npowers)) * one_log_ten)
idx := int((approx - firstpower) / steppowers)
for {
current := int(exp + powers_ten[idx].exp + 64)
if current < expmin {
idx++
continue
}
if current > expmax {
idx--
continue
}
*k = (firstpower + int64(idx)*steppowers)
return powers_ten[idx]
}
}
+46
View File
@@ -0,0 +1,46 @@
package miniredis
import (
"math"
"github.com/alicebob/miniredis/v2/geohash"
)
func toGeohash(long, lat float64) uint64 {
return geohash.EncodeIntWithPrecision(lat, long, 52)
}
func fromGeohash(score uint64) (float64, float64) {
lat, long := geohash.DecodeIntWithPrecision(score, 52)
return long, lat
}
// haversin(θ) function
func hsin(theta float64) float64 {
return math.Pow(math.Sin(theta/2), 2)
}
// distance function returns the distance (in meters) between two points of
// a given longitude and latitude relatively accurately (using a spherical
// approximation of the Earth) through the Haversin Distance Formula for
// great arc distance on a sphere with accuracy for small distances
// point coordinates are supplied in degrees and converted into rad. in the func
// distance returned is meters
// http://en.wikipedia.org/wiki/Haversine_formula
// Source: https://gist.github.com/cdipaolo/d3f8db3848278b49db68
func distance(lat1, lon1, lat2, lon2 float64) float64 {
// convert to radians
// must cast radius as float to multiply later
var la1, lo1, la2, lo2 float64
la1 = lat1 * math.Pi / 180
lo1 = lon1 * math.Pi / 180
la2 = lat2 * math.Pi / 180
lo2 = lon2 * math.Pi / 180
earth := 6372797.560856 // Earth radius in METERS, according to src/geohash_helper.c
// calculate
h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1)
return 2 * earth * math.Asin(math.Sqrt(h))
}
+22
View File
@@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2015 Michael McLoughlin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+2
View File
@@ -0,0 +1,2 @@
This is a (selected) copy of github.com/mmcloughlin/geohash with the latitude
range changed from 90 to ~85, to align with the algorithm use by Redis.
+44
View File
@@ -0,0 +1,44 @@
package geohash
// encoding encapsulates an encoding defined by a given base32 alphabet.
type encoding struct {
encode string
decode [256]byte
}
// newEncoding constructs a new encoding defined by the given alphabet,
// which must be a 32-byte string.
func newEncoding(encoder string) *encoding {
e := new(encoding)
e.encode = encoder
for i := 0; i < len(e.decode); i++ {
e.decode[i] = 0xff
}
for i := 0; i < len(encoder); i++ {
e.decode[encoder[i]] = byte(i)
}
return e
}
// Decode string into bits of a 64-bit word. The string s may be at most 12
// characters.
func (e *encoding) Decode(s string) uint64 {
x := uint64(0)
for i := 0; i < len(s); i++ {
x = (x << 5) | uint64(e.decode[s[i]])
}
return x
}
// Encode bits of 64-bit word into a string.
func (e *encoding) Encode(x uint64) string {
b := [12]byte{}
for i := 0; i < 12; i++ {
b[11-i] = e.encode[x&0x1f]
x >>= 5
}
return string(b[:])
}
// Base32Encoding with the Geohash alphabet.
var base32encoding = newEncoding("0123456789bcdefghjkmnpqrstuvwxyz")
+269
View File
@@ -0,0 +1,269 @@
// Package geohash provides encoding and decoding of string and integer
// geohashes.
package geohash
import (
"math"
)
const (
ENC_LAT = 85.05112878
ENC_LONG = 180.0
)
// Direction represents directions in the latitute/longitude space.
type Direction int
// Cardinal and intercardinal directions
const (
North Direction = iota
NorthEast
East
SouthEast
South
SouthWest
West
NorthWest
)
// Encode the point (lat, lng) as a string geohash with the standard 12
// characters of precision.
func Encode(lat, lng float64) string {
return EncodeWithPrecision(lat, lng, 12)
}
// EncodeWithPrecision encodes the point (lat, lng) as a string geohash with
// the specified number of characters of precision (max 12).
func EncodeWithPrecision(lat, lng float64, chars uint) string {
bits := 5 * chars
inthash := EncodeIntWithPrecision(lat, lng, bits)
enc := base32encoding.Encode(inthash)
return enc[12-chars:]
}
// encodeInt provides a Go implementation of integer geohash. This is the
// default implementation of EncodeInt, but optimized versions are provided
// for certain architectures.
func EncodeInt(lat, lng float64) uint64 {
latInt := encodeRange(lat, ENC_LAT)
lngInt := encodeRange(lng, ENC_LONG)
return interleave(latInt, lngInt)
}
// EncodeIntWithPrecision encodes the point (lat, lng) to an integer with the
// specified number of bits.
func EncodeIntWithPrecision(lat, lng float64, bits uint) uint64 {
hash := EncodeInt(lat, lng)
return hash >> (64 - bits)
}
// Box represents a rectangle in latitude/longitude space.
type Box struct {
MinLat float64
MaxLat float64
MinLng float64
MaxLng float64
}
// Center returns the center of the box.
func (b Box) Center() (lat, lng float64) {
lat = (b.MinLat + b.MaxLat) / 2.0
lng = (b.MinLng + b.MaxLng) / 2.0
return
}
// Contains decides whether (lat, lng) is contained in the box. The
// containment test is inclusive of the edges and corners.
func (b Box) Contains(lat, lng float64) bool {
return (b.MinLat <= lat && lat <= b.MaxLat &&
b.MinLng <= lng && lng <= b.MaxLng)
}
// errorWithPrecision returns the error range in latitude and longitude for in
// integer geohash with bits of precision.
func errorWithPrecision(bits uint) (latErr, lngErr float64) {
b := int(bits)
latBits := b / 2
lngBits := b - latBits
latErr = math.Ldexp(180.0, -latBits)
lngErr = math.Ldexp(360.0, -lngBits)
return
}
// BoundingBox returns the region encoded by the given string geohash.
func BoundingBox(hash string) Box {
bits := uint(5 * len(hash))
inthash := base32encoding.Decode(hash)
return BoundingBoxIntWithPrecision(inthash, bits)
}
// BoundingBoxIntWithPrecision returns the region encoded by the integer
// geohash with the specified precision.
func BoundingBoxIntWithPrecision(hash uint64, bits uint) Box {
fullHash := hash << (64 - bits)
latInt, lngInt := deinterleave(fullHash)
lat := decodeRange(latInt, ENC_LAT)
lng := decodeRange(lngInt, ENC_LONG)
latErr, lngErr := errorWithPrecision(bits)
return Box{
MinLat: lat,
MaxLat: lat + latErr,
MinLng: lng,
MaxLng: lng + lngErr,
}
}
// BoundingBoxInt returns the region encoded by the given 64-bit integer
// geohash.
func BoundingBoxInt(hash uint64) Box {
return BoundingBoxIntWithPrecision(hash, 64)
}
// DecodeCenter decodes the string geohash to the central point of the bounding box.
func DecodeCenter(hash string) (lat, lng float64) {
box := BoundingBox(hash)
return box.Center()
}
// DecodeIntWithPrecision decodes the provided integer geohash with bits of
// precision to a (lat, lng) point.
func DecodeIntWithPrecision(hash uint64, bits uint) (lat, lng float64) {
box := BoundingBoxIntWithPrecision(hash, bits)
return box.Center()
}
// DecodeInt decodes the provided 64-bit integer geohash to a (lat, lng) point.
func DecodeInt(hash uint64) (lat, lng float64) {
return DecodeIntWithPrecision(hash, 64)
}
// Neighbors returns a slice of geohash strings that correspond to the provided
// geohash's neighbors.
func Neighbors(hash string) []string {
box := BoundingBox(hash)
lat, lng := box.Center()
latDelta := box.MaxLat - box.MinLat
lngDelta := box.MaxLng - box.MinLng
precision := uint(len(hash))
return []string{
// N
EncodeWithPrecision(lat+latDelta, lng, precision),
// NE,
EncodeWithPrecision(lat+latDelta, lng+lngDelta, precision),
// E,
EncodeWithPrecision(lat, lng+lngDelta, precision),
// SE,
EncodeWithPrecision(lat-latDelta, lng+lngDelta, precision),
// S,
EncodeWithPrecision(lat-latDelta, lng, precision),
// SW,
EncodeWithPrecision(lat-latDelta, lng-lngDelta, precision),
// W,
EncodeWithPrecision(lat, lng-lngDelta, precision),
// NW
EncodeWithPrecision(lat+latDelta, lng-lngDelta, precision),
}
}
// NeighborsInt returns a slice of uint64s that correspond to the provided hash's
// neighbors at 64-bit precision.
func NeighborsInt(hash uint64) []uint64 {
return NeighborsIntWithPrecision(hash, 64)
}
// NeighborsIntWithPrecision returns a slice of uint64s that correspond to the
// provided hash's neighbors at the given precision.
func NeighborsIntWithPrecision(hash uint64, bits uint) []uint64 {
box := BoundingBoxIntWithPrecision(hash, bits)
lat, lng := box.Center()
latDelta := box.MaxLat - box.MinLat
lngDelta := box.MaxLng - box.MinLng
return []uint64{
// N
EncodeIntWithPrecision(lat+latDelta, lng, bits),
// NE,
EncodeIntWithPrecision(lat+latDelta, lng+lngDelta, bits),
// E,
EncodeIntWithPrecision(lat, lng+lngDelta, bits),
// SE,
EncodeIntWithPrecision(lat-latDelta, lng+lngDelta, bits),
// S,
EncodeIntWithPrecision(lat-latDelta, lng, bits),
// SW,
EncodeIntWithPrecision(lat-latDelta, lng-lngDelta, bits),
// W,
EncodeIntWithPrecision(lat, lng-lngDelta, bits),
// NW
EncodeIntWithPrecision(lat+latDelta, lng-lngDelta, bits),
}
}
// Neighbor returns a geohash string that corresponds to the provided
// geohash's neighbor in the provided direction
func Neighbor(hash string, direction Direction) string {
return Neighbors(hash)[direction]
}
// NeighborInt returns a uint64 that corresponds to the provided hash's
// neighbor in the provided direction at 64-bit precision.
func NeighborInt(hash uint64, direction Direction) uint64 {
return NeighborsIntWithPrecision(hash, 64)[direction]
}
// NeighborIntWithPrecision returns a uint64s that corresponds to the
// provided hash's neighbor in the provided direction at the given precision.
func NeighborIntWithPrecision(hash uint64, bits uint, direction Direction) uint64 {
return NeighborsIntWithPrecision(hash, bits)[direction]
}
// precalculated for performance
var exp232 = math.Exp2(32)
// Encode the position of x within the range -r to +r as a 32-bit integer.
func encodeRange(x, r float64) uint32 {
p := (x + r) / (2 * r)
return uint32(p * exp232)
}
// Decode the 32-bit range encoding X back to a value in the range -r to +r.
func decodeRange(X uint32, r float64) float64 {
p := float64(X) / exp232
x := 2*r*p - r
return x
}
// Spread out the 32 bits of x into 64 bits, where the bits of x occupy even
// bit positions.
func spread(x uint32) uint64 {
X := uint64(x)
X = (X | (X << 16)) & 0x0000ffff0000ffff
X = (X | (X << 8)) & 0x00ff00ff00ff00ff
X = (X | (X << 4)) & 0x0f0f0f0f0f0f0f0f
X = (X | (X << 2)) & 0x3333333333333333
X = (X | (X << 1)) & 0x5555555555555555
return X
}
// Interleave the bits of x and y. In the result, x and y occupy even and odd
// bitlevels, respectively.
func interleave(x, y uint32) uint64 {
return spread(x) | (spread(y) << 1)
}
// Squash the even bitlevels of X into a 32-bit word. Odd bitlevels of X are
// ignored, and may take any value.
func squash(X uint64) uint32 {
X &= 0x5555555555555555
X = (X | (X >> 1)) & 0x3333333333333333
X = (X | (X >> 2)) & 0x0f0f0f0f0f0f0f0f
X = (X | (X >> 4)) & 0x00ff00ff00ff00ff
X = (X | (X >> 8)) & 0x0000ffff0000ffff
X = (X | (X >> 16)) & 0x00000000ffffffff
return uint32(X)
}
// Deinterleave the bits of X into 32-bit words containing the even and odd
// bitlevels of X, respectively.
func deinterleave(X uint64) (uint32, uint32) {
return squash(X), squash(X >> 1)
}
+24
View File
@@ -0,0 +1,24 @@
This is free and unencumbered software released into the public domain.
Anyone is free to copy, modify, publish, use, compile, sell, or
distribute this software, either in source code form or as a compiled
binary, for any purpose, commercial or non-commercial, and by any
means.
In jurisdictions that recognize copyright laws, the author or authors
of this software dedicate any and all copyright interest in the
software to the public domain. We make this dedication for the benefit
of the public at large and to the detriment of our heirs and
successors. We intend this dedication to be an overt act of
relinquishment in perpetuity of all present and future rights to this
software under copyright law.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
For more information, please refer to <http://unlicense.org/>
+1
View File
@@ -0,0 +1 @@
Copied from https://github.com/layeh/gopher-json and https://github.com/alicebob/gopher-json
+189
View File
@@ -0,0 +1,189 @@
package json
import (
"encoding/json"
"errors"
"github.com/yuin/gopher-lua"
)
// Preload adds json to the given Lua state's package.preload table. After it
// has been preloaded, it can be loaded using require:
//
// local json = require("json")
func Preload(L *lua.LState) {
L.PreloadModule("json", Loader)
}
// Loader is the module loader function.
func Loader(L *lua.LState) int {
t := L.NewTable()
L.SetFuncs(t, api)
L.Push(t)
return 1
}
var api = map[string]lua.LGFunction{
"decode": apiDecode,
"encode": apiEncode,
}
func apiDecode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to decode"), 1)
return 0
}
str := L.CheckString(1)
value, err := Decode(L, []byte(str))
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(value)
return 1
}
func apiEncode(L *lua.LState) int {
if L.GetTop() != 1 {
L.Error(lua.LString("bad argument #1 to encode"), 1)
return 0
}
value := L.CheckAny(1)
data, err := Encode(value)
if err != nil {
L.Push(lua.LNil)
L.Push(lua.LString(err.Error()))
return 2
}
L.Push(lua.LString(string(data)))
return 1
}
var (
errNested = errors.New("cannot encode recursively nested tables to JSON")
errSparseArray = errors.New("cannot encode sparse array")
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
)
type invalidTypeError lua.LValueType
func (i invalidTypeError) Error() string {
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
}
// Encode returns the JSON encoding of value.
func Encode(value lua.LValue) ([]byte, error) {
return json.Marshal(jsonValue{
LValue: value,
visited: make(map[*lua.LTable]bool),
})
}
type jsonValue struct {
lua.LValue
visited map[*lua.LTable]bool
}
func (j jsonValue) MarshalJSON() (data []byte, err error) {
switch converted := j.LValue.(type) {
case lua.LBool:
data, err = json.Marshal(bool(converted))
case lua.LNumber:
data, err = json.Marshal(float64(converted))
case *lua.LNilType:
data = []byte(`null`)
case lua.LString:
data, err = json.Marshal(string(converted))
case *lua.LTable:
if j.visited[converted] {
return nil, errNested
}
j.visited[converted] = true
key, value := converted.Next(lua.LNil)
switch key.Type() {
case lua.LTNil: // empty table
data = []byte(`[]`)
case lua.LTNumber:
arr := make([]jsonValue, 0, converted.Len())
expectedKey := lua.LNumber(1)
for key != lua.LNil {
if key.Type() != lua.LTNumber {
err = errInvalidKeys
return
}
if expectedKey != key {
err = errSparseArray
return
}
arr = append(arr, jsonValue{value, j.visited})
expectedKey++
key, value = converted.Next(key)
}
data, err = json.Marshal(arr)
case lua.LTString:
obj := make(map[string]jsonValue)
for key != lua.LNil {
if key.Type() != lua.LTString {
err = errInvalidKeys
return
}
obj[key.String()] = jsonValue{value, j.visited}
key, value = converted.Next(key)
}
data, err = json.Marshal(obj)
default:
err = errInvalidKeys
}
default:
err = invalidTypeError(j.LValue.Type())
}
return
}
// Decode converts the JSON encoded data to Lua values.
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
var value interface{}
err := json.Unmarshal(data, &value)
if err != nil {
return nil, err
}
return DecodeValue(L, value), nil
}
// DecodeValue converts the value to a Lua value.
//
// This function only converts values that the encoding/json package decodes to.
// All other values will return lua.LNil.
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
switch converted := value.(type) {
case bool:
return lua.LBool(converted)
case float64:
return lua.LNumber(converted)
case string:
return lua.LString(converted)
case json.Number:
return lua.LString(converted)
case []interface{}:
arr := L.CreateTable(len(converted), 0)
for _, item := range converted {
arr.Append(DecodeValue(L, item))
}
return arr
case map[string]interface{}:
tbl := L.CreateTable(0, len(converted))
for key, item := range converted {
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
}
return tbl
case nil:
return lua.LNil
}
return lua.LNil
}
+42
View File
@@ -0,0 +1,42 @@
package miniredis
import (
"github.com/alicebob/miniredis/v2/hyperloglog"
)
type hll struct {
inner *hyperloglog.Sketch
}
func newHll() *hll {
return &hll{
inner: hyperloglog.New14(),
}
}
// Add returns true if cardinality has been changed, or false otherwise.
func (h *hll) Add(item []byte) bool {
return h.inner.Insert(item)
}
// Count returns the estimation of a set cardinality.
func (h *hll) Count() int {
return int(h.inner.Estimate())
}
// Merge merges the other hll into original one (not making a copy but doing this in place).
func (h *hll) Merge(other *hll) {
_ = h.inner.Merge(other.inner)
}
// Bytes returns raw-bytes representation of hll data structure.
func (h *hll) Bytes() []byte {
dataBytes, _ := h.inner.MarshalBinary()
return dataBytes
}
func (h *hll) copy() *hll {
return &hll{
inner: h.inner.Clone(),
}
}
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2017 Axiom Inc. <seif@axiom.sh>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+1
View File
@@ -0,0 +1 @@
This is a copy of github.com/axiomhq/hyperloglog.
+180
View File
@@ -0,0 +1,180 @@
package hyperloglog
import "encoding/binary"
// Original author of this file is github.com/clarkduvall/hyperloglog
type iterable interface {
decode(i int, last uint32) (uint32, int)
Len() int
Iter() *iterator
}
type iterator struct {
i int
last uint32
v iterable
}
func (iter *iterator) Next() uint32 {
n, i := iter.v.decode(iter.i, iter.last)
iter.last = n
iter.i = i
return n
}
func (iter *iterator) Peek() uint32 {
n, _ := iter.v.decode(iter.i, iter.last)
return n
}
func (iter iterator) HasNext() bool {
return iter.i < iter.v.Len()
}
type compressedList struct {
count uint32
last uint32
b variableLengthList
}
func (v *compressedList) Clone() *compressedList {
if v == nil {
return nil
}
newV := &compressedList{
count: v.count,
last: v.last,
}
newV.b = make(variableLengthList, len(v.b))
copy(newV.b, v.b)
return newV
}
func (v *compressedList) MarshalBinary() (data []byte, err error) {
// Marshal the variableLengthList
bdata, err := v.b.MarshalBinary()
if err != nil {
return nil, err
}
// At least 4 bytes for the two fixed sized values plus the size of bdata.
data = make([]byte, 0, 4+4+len(bdata))
// Marshal the count and last values.
data = append(data, []byte{
// Number of items in the list.
byte(v.count >> 24),
byte(v.count >> 16),
byte(v.count >> 8),
byte(v.count),
// The last item in the list.
byte(v.last >> 24),
byte(v.last >> 16),
byte(v.last >> 8),
byte(v.last),
}...)
// Append the list
return append(data, bdata...), nil
}
func (v *compressedList) UnmarshalBinary(data []byte) error {
if len(data) < 12 {
return ErrorTooShort
}
// Set the count.
v.count, data = binary.BigEndian.Uint32(data[:4]), data[4:]
// Set the last value.
v.last, data = binary.BigEndian.Uint32(data[:4]), data[4:]
// Set the list.
sz, data := binary.BigEndian.Uint32(data[:4]), data[4:]
v.b = make([]uint8, sz)
if uint32(len(data)) < sz {
return ErrorTooShort
}
for i := uint32(0); i < sz; i++ {
v.b[i] = data[i]
}
return nil
}
func newCompressedList() *compressedList {
v := &compressedList{}
v.b = make(variableLengthList, 0)
return v
}
func (v *compressedList) Len() int {
return len(v.b)
}
func (v *compressedList) decode(i int, last uint32) (uint32, int) {
n, i := v.b.decode(i, last)
return n + last, i
}
func (v *compressedList) Append(x uint32) {
v.count++
v.b = v.b.Append(x - v.last)
v.last = x
}
func (v *compressedList) Iter() *iterator {
return &iterator{0, 0, v}
}
type variableLengthList []uint8
func (v variableLengthList) MarshalBinary() (data []byte, err error) {
// 4 bytes for the size of the list, and a byte for each element in the
// list.
data = make([]byte, 0, 4+v.Len())
// Length of the list. We only need 32 bits because the size of the set
// couldn't exceed that on 32 bit architectures.
sz := v.Len()
data = append(data, []byte{
byte(sz >> 24),
byte(sz >> 16),
byte(sz >> 8),
byte(sz),
}...)
// Marshal each element in the list.
for i := 0; i < sz; i++ {
data = append(data, v[i])
}
return data, nil
}
func (v variableLengthList) Len() int {
return len(v)
}
func (v *variableLengthList) Iter() *iterator {
return &iterator{0, 0, v}
}
func (v variableLengthList) decode(i int, last uint32) (uint32, int) {
var x uint32
j := i
for ; v[j]&0x80 != 0; j++ {
x |= uint32(v[j]&0x7f) << (uint(j-i) * 7)
}
x |= uint32(v[j]) << (uint(j-i) * 7)
return x, j + 1
}
func (v variableLengthList) Append(x uint32) variableLengthList {
for x&0xffffff80 != 0 {
v = append(v, uint8((x&0x7f)|0x80))
x >>= 7
}
return append(v, uint8(x&0x7f))
}
+424
View File
@@ -0,0 +1,424 @@
package hyperloglog
import (
"encoding/binary"
"errors"
"fmt"
"math"
"sort"
)
const (
capacity = uint8(16)
pp = uint8(25)
mp = uint32(1) << pp
version = 1
)
// Sketch is a HyperLogLog data-structure for the count-distinct problem,
// approximating the number of distinct elements in a multiset.
type Sketch struct {
p uint8
b uint8
m uint32
alpha float64
tmpSet set
sparseList *compressedList
regs *registers
}
// New returns a HyperLogLog Sketch with 2^14 registers (precision 14)
func New() *Sketch {
return New14()
}
// New14 returns a HyperLogLog Sketch with 2^14 registers (precision 14)
func New14() *Sketch {
sk, _ := newSketch(14, true)
return sk
}
// New16 returns a HyperLogLog Sketch with 2^16 registers (precision 16)
func New16() *Sketch {
sk, _ := newSketch(16, true)
return sk
}
// NewNoSparse returns a HyperLogLog Sketch with 2^14 registers (precision 14)
// that will not use a sparse representation
func NewNoSparse() *Sketch {
sk, _ := newSketch(14, false)
return sk
}
// New16NoSparse returns a HyperLogLog Sketch with 2^16 registers (precision 16)
// that will not use a sparse representation
func New16NoSparse() *Sketch {
sk, _ := newSketch(16, false)
return sk
}
// newSketch returns a HyperLogLog Sketch with 2^precision registers
func newSketch(precision uint8, sparse bool) (*Sketch, error) {
if precision < 4 || precision > 18 {
return nil, fmt.Errorf("p has to be >= 4 and <= 18")
}
m := uint32(math.Pow(2, float64(precision)))
s := &Sketch{
m: m,
p: precision,
alpha: alpha(float64(m)),
}
if sparse {
s.tmpSet = set{}
s.sparseList = newCompressedList()
} else {
s.regs = newRegisters(m)
}
return s, nil
}
func (sk *Sketch) sparse() bool {
return sk.sparseList != nil
}
// Clone returns a deep copy of sk.
func (sk *Sketch) Clone() *Sketch {
return &Sketch{
b: sk.b,
p: sk.p,
m: sk.m,
alpha: sk.alpha,
tmpSet: sk.tmpSet.Clone(),
sparseList: sk.sparseList.Clone(),
regs: sk.regs.clone(),
}
}
// Converts to normal if the sparse list is too large.
func (sk *Sketch) maybeToNormal() {
if uint32(len(sk.tmpSet))*100 > sk.m {
sk.mergeSparse()
if uint32(sk.sparseList.Len()) > sk.m {
sk.toNormal()
}
}
}
// Merge takes another Sketch and combines it with Sketch h.
// If Sketch h is using the sparse Sketch, it will be converted
// to the normal Sketch.
func (sk *Sketch) Merge(other *Sketch) error {
if other == nil {
// Nothing to do
return nil
}
cpOther := other.Clone()
if sk.p != cpOther.p {
return errors.New("precisions must be equal")
}
if sk.sparse() && other.sparse() {
for k := range other.tmpSet {
sk.tmpSet.add(k)
}
for iter := other.sparseList.Iter(); iter.HasNext(); {
sk.tmpSet.add(iter.Next())
}
sk.maybeToNormal()
return nil
}
if sk.sparse() {
sk.toNormal()
}
if cpOther.sparse() {
for k := range cpOther.tmpSet {
i, r := decodeHash(k, cpOther.p, pp)
sk.insert(i, r)
}
for iter := cpOther.sparseList.Iter(); iter.HasNext(); {
i, r := decodeHash(iter.Next(), cpOther.p, pp)
sk.insert(i, r)
}
} else {
if sk.b < cpOther.b {
sk.regs.rebase(cpOther.b - sk.b)
sk.b = cpOther.b
} else {
cpOther.regs.rebase(sk.b - cpOther.b)
cpOther.b = sk.b
}
for i, v := range cpOther.regs.tailcuts {
v1 := v.get(0)
if v1 > sk.regs.get(uint32(i)*2) {
sk.regs.set(uint32(i)*2, v1)
}
v2 := v.get(1)
if v2 > sk.regs.get(1+uint32(i)*2) {
sk.regs.set(1+uint32(i)*2, v2)
}
}
}
return nil
}
// Convert from sparse Sketch to dense Sketch.
func (sk *Sketch) toNormal() {
if len(sk.tmpSet) > 0 {
sk.mergeSparse()
}
sk.regs = newRegisters(sk.m)
for iter := sk.sparseList.Iter(); iter.HasNext(); {
i, r := decodeHash(iter.Next(), sk.p, pp)
sk.insert(i, r)
}
sk.tmpSet = nil
sk.sparseList = nil
}
func (sk *Sketch) insert(i uint32, r uint8) bool {
changed := false
if r-sk.b >= capacity {
//overflow
db := sk.regs.min()
if db > 0 {
sk.b += db
sk.regs.rebase(db)
changed = true
}
}
if r > sk.b {
val := r - sk.b
if c1 := capacity - 1; c1 < val {
val = c1
}
if val > sk.regs.get(i) {
sk.regs.set(i, val)
changed = true
}
}
return changed
}
// Insert adds element e to sketch
func (sk *Sketch) Insert(e []byte) bool {
x := hash(e)
return sk.InsertHash(x)
}
// InsertHash adds hash x to sketch
func (sk *Sketch) InsertHash(x uint64) bool {
if sk.sparse() {
changed := sk.tmpSet.add(encodeHash(x, sk.p, pp))
if !changed {
return false
}
if uint32(len(sk.tmpSet))*100 > sk.m/2 {
sk.mergeSparse()
if uint32(sk.sparseList.Len()) > sk.m/2 {
sk.toNormal()
}
}
return true
} else {
i, r := getPosVal(x, sk.p)
return sk.insert(uint32(i), r)
}
}
// Estimate returns the cardinality of the Sketch
func (sk *Sketch) Estimate() uint64 {
if sk.sparse() {
sk.mergeSparse()
return uint64(linearCount(mp, mp-sk.sparseList.count))
}
sum, ez := sk.regs.sumAndZeros(sk.b)
m := float64(sk.m)
var est float64
var beta func(float64) float64
if sk.p < 16 {
beta = beta14
} else {
beta = beta16
}
if sk.b == 0 {
est = (sk.alpha * m * (m - ez) / (sum + beta(ez)))
} else {
est = (sk.alpha * m * m / sum)
}
return uint64(est + 0.5)
}
func (sk *Sketch) mergeSparse() {
if len(sk.tmpSet) == 0 {
return
}
keys := make(uint64Slice, 0, len(sk.tmpSet))
for k := range sk.tmpSet {
keys = append(keys, k)
}
sort.Sort(keys)
newList := newCompressedList()
for iter, i := sk.sparseList.Iter(), 0; iter.HasNext() || i < len(keys); {
if !iter.HasNext() {
newList.Append(keys[i])
i++
continue
}
if i >= len(keys) {
newList.Append(iter.Next())
continue
}
x1, x2 := iter.Peek(), keys[i]
if x1 == x2 {
newList.Append(iter.Next())
i++
} else if x1 > x2 {
newList.Append(x2)
i++
} else {
newList.Append(iter.Next())
}
}
sk.sparseList = newList
sk.tmpSet = set{}
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (sk *Sketch) MarshalBinary() (data []byte, err error) {
// Marshal a version marker.
data = append(data, version)
// Marshal p.
data = append(data, sk.p)
// Marshal b
data = append(data, sk.b)
if sk.sparse() {
// It's using the sparse Sketch.
data = append(data, byte(1))
// Add the tmp_set
tsdata, err := sk.tmpSet.MarshalBinary()
if err != nil {
return nil, err
}
data = append(data, tsdata...)
// Add the sparse Sketch
sdata, err := sk.sparseList.MarshalBinary()
if err != nil {
return nil, err
}
return append(data, sdata...), nil
}
// It's using the dense Sketch.
data = append(data, byte(0))
// Add the dense sketch Sketch.
sz := len(sk.regs.tailcuts)
data = append(data, []byte{
byte(sz >> 24),
byte(sz >> 16),
byte(sz >> 8),
byte(sz),
}...)
// Marshal each element in the list.
for i := 0; i < len(sk.regs.tailcuts); i++ {
data = append(data, byte(sk.regs.tailcuts[i]))
}
return data, nil
}
// ErrorTooShort is an error that UnmarshalBinary try to parse too short
// binary.
var ErrorTooShort = errors.New("too short binary")
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func (sk *Sketch) UnmarshalBinary(data []byte) error {
if len(data) < 8 {
return ErrorTooShort
}
// Unmarshal version. We may need this in the future if we make
// non-compatible changes.
_ = data[0]
// Unmarshal p.
p := data[1]
// Unmarshal b.
sk.b = data[2]
// Determine if we need a sparse Sketch
sparse := data[3] == byte(1)
// Make a newSketch Sketch if the precision doesn't match or if the Sketch was used
if sk.p != p || sk.regs != nil || len(sk.tmpSet) > 0 || (sk.sparseList != nil && sk.sparseList.Len() > 0) {
newh, err := newSketch(p, sparse)
if err != nil {
return err
}
newh.b = sk.b
*sk = *newh
}
// h is now initialised with the correct p. We just need to fill the
// rest of the details out.
if sparse {
// Using the sparse Sketch.
// Unmarshal the tmp_set.
tssz := binary.BigEndian.Uint32(data[4:8])
sk.tmpSet = make(map[uint32]struct{}, tssz)
// We need to unmarshal tssz values in total, and each value requires us
// to read 4 bytes.
tsLastByte := int((tssz * 4) + 8)
for i := 8; i < tsLastByte; i += 4 {
k := binary.BigEndian.Uint32(data[i : i+4])
sk.tmpSet[k] = struct{}{}
}
// Unmarshal the sparse Sketch.
return sk.sparseList.UnmarshalBinary(data[tsLastByte:])
}
// Using the dense Sketch.
sk.sparseList = nil
sk.tmpSet = nil
dsz := binary.BigEndian.Uint32(data[4:8])
sk.regs = newRegisters(dsz * 2)
data = data[8:]
for i, val := range data {
sk.regs.tailcuts[i] = reg(val)
if uint8(sk.regs.tailcuts[i]<<4>>4) > 0 {
sk.regs.nz--
}
if uint8(sk.regs.tailcuts[i]>>4) > 0 {
sk.regs.nz--
}
}
return nil
}
+114
View File
@@ -0,0 +1,114 @@
package hyperloglog
import (
"math"
)
type reg uint8
type tailcuts []reg
type registers struct {
tailcuts
nz uint32
}
func (r *reg) set(offset, val uint8) bool {
var isZero bool
if offset == 0 {
isZero = *r < 16
tmpVal := uint8((*r) << 4 >> 4)
*r = reg(tmpVal | (val << 4))
} else {
isZero = *r&0x0f == 0
tmpVal := uint8((*r) >> 4 << 4)
*r = reg(tmpVal | val)
}
return isZero
}
func (r *reg) get(offset uint8) uint8 {
if offset == 0 {
return uint8((*r) >> 4)
}
return uint8((*r) << 4 >> 4)
}
func newRegisters(size uint32) *registers {
return &registers{
tailcuts: make(tailcuts, size/2),
nz: size,
}
}
func (rs *registers) clone() *registers {
if rs == nil {
return nil
}
tc := make([]reg, len(rs.tailcuts))
copy(tc, rs.tailcuts)
return &registers{
tailcuts: tc,
nz: rs.nz,
}
}
func (rs *registers) rebase(delta uint8) {
nz := uint32(len(rs.tailcuts)) * 2
for i := range rs.tailcuts {
for j := uint8(0); j < 2; j++ {
val := rs.tailcuts[i].get(j)
if val >= delta {
rs.tailcuts[i].set(j, val-delta)
if val-delta > 0 {
nz--
}
}
}
}
rs.nz = nz
}
func (rs *registers) set(i uint32, val uint8) {
offset, index := uint8(i)&1, i/2
if rs.tailcuts[index].set(offset, val) {
rs.nz--
}
}
func (rs *registers) get(i uint32) uint8 {
offset, index := uint8(i)&1, i/2
return rs.tailcuts[index].get(offset)
}
func (rs *registers) sumAndZeros(base uint8) (res, ez float64) {
for _, r := range rs.tailcuts {
for j := uint8(0); j < 2; j++ {
v := float64(base + r.get(j))
if v == 0 {
ez++
}
res += 1.0 / math.Pow(2.0, v)
}
}
rs.nz = uint32(ez)
return res, ez
}
func (rs *registers) min() uint8 {
if rs.nz > 0 {
return 0
}
min := uint8(math.MaxUint8)
for _, r := range rs.tailcuts {
if r == 0 || min == 0 {
return 0
}
if val := uint8(r << 4 >> 4); val < min {
min = val
}
if val := uint8(r >> 4); val < min {
min = val
}
}
return min
}
+92
View File
@@ -0,0 +1,92 @@
package hyperloglog
import (
"math/bits"
)
func getIndex(k uint32, p, pp uint8) uint32 {
if k&1 == 1 {
return bextr32(k, 32-p, p)
}
return bextr32(k, pp-p+1, p)
}
// Encode a hash to be used in the sparse representation.
func encodeHash(x uint64, p, pp uint8) uint32 {
idx := uint32(bextr(x, 64-pp, pp))
if bextr(x, 64-pp, pp-p) == 0 {
zeros := bits.LeadingZeros64((bextr(x, 0, 64-pp)<<pp)|(1<<pp-1)) + 1
return idx<<7 | uint32(zeros<<1) | 1
}
return idx << 1
}
// Decode a hash from the sparse representation.
func decodeHash(k uint32, p, pp uint8) (uint32, uint8) {
var r uint8
if k&1 == 1 {
r = uint8(bextr32(k, 1, 6)) + pp - p
} else {
// We can use the 64bit clz implementation and reduce the result
// by 32 to get a clz for a 32bit word.
r = uint8(bits.LeadingZeros64(uint64(k<<(32-pp+p-1))) - 31) // -32 + 1
}
return getIndex(k, p, pp), r
}
type set map[uint32]struct{}
func (s set) add(v uint32) bool {
_, ok := s[v]
if ok {
return false
}
s[v] = struct{}{}
return true
}
func (s set) Clone() set {
if s == nil {
return nil
}
newS := make(map[uint32]struct{}, len(s))
for k, v := range s {
newS[k] = v
}
return newS
}
func (s set) MarshalBinary() (data []byte, err error) {
// 4 bytes for the size of the set, and 4 bytes for each key.
// list.
data = make([]byte, 0, 4+(4*len(s)))
// Length of the set. We only need 32 bits because the size of the set
// couldn't exceed that on 32 bit architectures.
sl := len(s)
data = append(data, []byte{
byte(sl >> 24),
byte(sl >> 16),
byte(sl >> 8),
byte(sl),
}...)
// Marshal each element in the set.
for k := range s {
data = append(data, []byte{
byte(k >> 24),
byte(k >> 16),
byte(k >> 8),
byte(k),
}...)
}
return data, nil
}
type uint64Slice []uint32
func (p uint64Slice) Len() int { return len(p) }
func (p uint64Slice) Less(i, j int) bool { return p[i] < p[j] }
func (p uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+69
View File
@@ -0,0 +1,69 @@
package hyperloglog
import (
"github.com/alicebob/miniredis/v2/metro"
"math"
"math/bits"
)
var hash = hashFunc
func beta14(ez float64) float64 {
zl := math.Log(ez + 1)
return -0.370393911*ez +
0.070471823*zl +
0.17393686*math.Pow(zl, 2) +
0.16339839*math.Pow(zl, 3) +
-0.09237745*math.Pow(zl, 4) +
0.03738027*math.Pow(zl, 5) +
-0.005384159*math.Pow(zl, 6) +
0.00042419*math.Pow(zl, 7)
}
func beta16(ez float64) float64 {
zl := math.Log(ez + 1)
return -0.37331876643753059*ez +
-1.41704077448122989*zl +
0.40729184796612533*math.Pow(zl, 2) +
1.56152033906584164*math.Pow(zl, 3) +
-0.99242233534286128*math.Pow(zl, 4) +
0.26064681399483092*math.Pow(zl, 5) +
-0.03053811369682807*math.Pow(zl, 6) +
0.00155770210179105*math.Pow(zl, 7)
}
func alpha(m float64) float64 {
switch m {
case 16:
return 0.673
case 32:
return 0.697
case 64:
return 0.709
}
return 0.7213 / (1 + 1.079/m)
}
func getPosVal(x uint64, p uint8) (uint64, uint8) {
i := bextr(x, 64-p, p) // {x63,...,x64-p}
w := x<<p | 1<<(p-1) // {x63-p,...,x0}
rho := uint8(bits.LeadingZeros64(w)) + 1
return i, rho
}
func linearCount(m uint32, v uint32) float64 {
fm := float64(m)
return fm * math.Log(fm/float64(v))
}
func bextr(v uint64, start, length uint8) uint64 {
return (v >> start) & ((1 << length) - 1)
}
func bextr32(v uint32, start, length uint8) uint32 {
return (v >> start) & ((1 << length) - 1)
}
func hashFunc(e []byte) uint64 {
return metro.Hash64(e, 1337)
}
+83
View File
@@ -0,0 +1,83 @@
package miniredis
// Translate the 'KEYS' or 'PSUBSCRIBE' argument ('foo*', 'f??', &c.) into a regexp.
import (
"bytes"
"regexp"
)
// patternRE compiles a glob to a regexp. Returns nil if the given
// pattern will never match anything.
// The general strategy is to sandwich all non-meta characters between \Q...\E.
func patternRE(k string) *regexp.Regexp {
re := bytes.Buffer{}
re.WriteString(`(?s)^\Q`)
for i := 0; i < len(k); i++ {
p := k[i]
switch p {
case '*':
re.WriteString(`\E.*\Q`)
case '?':
re.WriteString(`\E.\Q`)
case '[':
charClass := bytes.Buffer{}
i++
for ; i < len(k); i++ {
if k[i] == ']' {
break
}
if k[i] == '\\' {
if i == len(k)-1 {
// Ends with a '\'. U-huh.
return nil
}
charClass.WriteByte(k[i])
i++
charClass.WriteByte(k[i])
continue
}
charClass.WriteByte(k[i])
}
if charClass.Len() == 0 {
// '[]' is valid in Redis, but matches nothing.
return nil
}
re.WriteString(`\E[`)
re.Write(charClass.Bytes())
re.WriteString(`]\Q`)
case '\\':
if i == len(k)-1 {
// Ends with a '\'. U-huh.
return nil
}
// Forget the \, keep the next char.
i++
re.WriteByte(k[i])
continue
default:
re.WriteByte(p)
}
}
re.WriteString(`\E$`)
return regexp.MustCompile(re.String())
}
// matchKeys filters only matching keys.
// The returned boolean is whether the match pattern was valid
func matchKeys(keys []string, match string) ([]string, bool) {
re := patternRE(match)
if re == nil {
// Special case: the given pattern won't match anything or is invalid.
return nil, false
}
var res []string
for _, k := range keys {
if !re.MatchString(k) {
continue
}
res = append(res, k)
}
return res, true
}
+281
View File
@@ -0,0 +1,281 @@
package miniredis
import (
"bufio"
"bytes"
"fmt"
"strings"
lua "github.com/yuin/gopher-lua"
"github.com/alicebob/miniredis/v2/server"
)
var luaRedisConstants = map[string]lua.LValue{
"LOG_DEBUG": lua.LNumber(0),
"LOG_VERBOSE": lua.LNumber(1),
"LOG_NOTICE": lua.LNumber(2),
"LOG_WARNING": lua.LNumber(3),
}
func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) {
mkCall := func(failFast bool) func(l *lua.LState) int {
// one server.Ctx for a single Lua run
pCtx := &connCtx{}
if getCtx(c).authenticated {
pCtx.authenticated = true
}
pCtx.nested = true
pCtx.nestedSHA = sha
pCtx.selectedDB = getCtx(c).selectedDB
return func(l *lua.LState) int {
top := l.GetTop()
if top == 0 {
l.Error(lua.LString(fmt.Sprintf("Please specify at least one argument for this redis lib call script: %s, &c.", sha)), 1)
return 0
}
var args []string
for i := 1; i <= top; i++ {
switch a := l.Get(i).(type) {
case lua.LNumber:
args = append(args, a.String())
case lua.LString:
args = append(args, string(a))
default:
l.Error(lua.LString(fmt.Sprintf("Lua redis lib command arguments must be strings or integers script: %s, &c.", sha)), 1)
return 0
}
}
if len(args) == 0 {
l.Error(lua.LString(msgNotFromScripts(sha)), 1)
return 0
}
buf := &bytes.Buffer{}
wr := bufio.NewWriter(buf)
peer := server.NewPeer(wr)
peer.Ctx = pCtx
srv.Dispatch(peer, args)
wr.Flush()
res, err := server.ParseReply(bufio.NewReader(buf))
if err != nil {
if failFast {
// call() mode
if strings.Contains(err.Error(), "ERR unknown command") {
l.Error(lua.LString(fmt.Sprintf("Unknown Redis command called from script script: %s, &c.", sha)), 1)
} else {
l.Error(lua.LString(err.Error()), 1)
}
return 0
}
// pcall() mode
l.Push(lua.LNil)
return 1
}
if res == nil {
l.Push(lua.LFalse)
} else {
switch r := res.(type) {
case int64:
l.Push(lua.LNumber(r))
case int:
l.Push(lua.LNumber(r))
case []uint8:
l.Push(lua.LString(string(r)))
case []interface{}:
l.Push(redisToLua(l, r))
case server.Simple:
l.Push(luaStatusReply(string(r)))
case string:
l.Push(lua.LString(r))
case error:
l.Error(lua.LString(r.Error()), 1)
return 0
default:
panic(fmt.Sprintf("type not handled (%T)", r))
}
}
return 1
}
}
return map[string]lua.LGFunction{
"call": mkCall(true),
"pcall": mkCall(false),
"error_reply": func(l *lua.LState) int {
v := l.Get(1)
msg, ok := v.(lua.LString)
if !ok {
l.Error(lua.LString("wrong number or type of arguments"), 1)
return 0
}
res := &lua.LTable{}
parts := strings.SplitN(msg.String(), " ", 2)
// '-' at the beginging will be added as a part of error response
if parts[0] != "" && parts[0][0] == '-' {
parts[0] = parts[0][1:]
}
var final_msg string
if len(parts) == 2 {
final_msg = fmt.Sprintf("%s %s", parts[0], parts[1])
} else {
final_msg = fmt.Sprintf("ERR %s", parts[0])
}
res.RawSetString("err", lua.LString(final_msg))
l.Push(res)
return 1
},
"log": func(l *lua.LState) int {
level := l.CheckInt(1)
msg := l.CheckString(2)
_, _ = level, msg
// do nothing by default. To see logs uncomment:
// fmt.Printf("%v: %v", level, msg)
return 0
},
"status_reply": func(l *lua.LState) int {
v := l.Get(1)
msg, ok := v.(lua.LString)
if !ok {
l.Error(lua.LString("wrong number or type of arguments"), 1)
return 0
}
res := luaStatusReply(string(msg))
l.Push(res)
return 1
},
"sha1hex": func(l *lua.LState) int {
top := l.GetTop()
if top != 1 {
l.Error(lua.LString("wrong number of arguments"), 1)
return 0
}
msg := lua.LVAsString(l.Get(1))
l.Push(lua.LString(sha1Hex(msg)))
return 1
},
"replicate_commands": func(l *lua.LState) int {
// always succeeds since 7.0.0
l.Push(lua.LTrue)
return 1
},
"set_repl": func(l *lua.LState) int {
top := l.GetTop()
if top != 1 {
l.Error(lua.LString("wrong number of arguments"), 1)
return 0
}
// ignored
return 1
},
"setresp": func(l *lua.LState) int {
level := l.CheckInt(1)
toresp3 := false
switch level {
case 2:
toresp3 = false
case 3:
toresp3 = true
default:
l.Error(lua.LString("RESP version must be 2 or 3"), 1)
return 0
}
c.SwitchResp3 = &toresp3
return 0
},
}, luaRedisConstants
}
func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) {
if value == nil {
c.WriteNull()
return
}
switch t := value.(type) {
case *lua.LNilType:
c.WriteNull()
case lua.LBool:
if lua.LVAsBool(value) {
c.WriteInt(1)
} else {
c.WriteNull()
}
case lua.LNumber:
c.WriteInt(int(lua.LVAsNumber(value)))
case lua.LString:
s := lua.LVAsString(value)
c.WriteBulk(s)
case *lua.LTable:
// special case for tables with an 'err' or 'ok' field
// note: according to the docs this only counts when 'err' or 'ok' is
// the only field.
if s := t.RawGetString("err"); s.Type() != lua.LTNil {
c.WriteError(s.String())
return
}
if s := t.RawGetString("ok"); s.Type() != lua.LTNil {
c.WriteInline(s.String())
return
}
result := []lua.LValue{}
for j := 1; true; j++ {
val := l.GetTable(value, lua.LNumber(j))
if val == nil {
result = append(result, val)
continue
}
if val.Type() == lua.LTNil {
break
}
result = append(result, val)
}
c.WriteLen(len(result))
for _, r := range result {
luaToRedis(l, c, r)
}
default:
panic(fmt.Sprintf("wat: %T", t))
}
}
func redisToLua(l *lua.LState, res []interface{}) *lua.LTable {
rettb := l.NewTable()
for _, e := range res {
var v lua.LValue
if e == nil {
v = lua.LFalse
} else {
switch et := e.(type) {
case int:
v = lua.LNumber(et)
case int64:
v = lua.LNumber(et)
case []uint8:
v = lua.LString(string(et))
case []interface{}:
v = redisToLua(l, et)
case string:
v = lua.LString(et)
default:
// TODO: oops?
v = lua.LString(e.(string))
}
}
l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v)
}
return rettb
}
func luaStatusReply(msg string) *lua.LTable {
tab := &lua.LTable{}
tab.RawSetString("ok", lua.LString(msg))
return tab
}
+24
View File
@@ -0,0 +1,24 @@
This package is a mechanical translation of the reference C++ code for
MetroHash, available at https://github.com/jandrewrogers/MetroHash
The MIT License (MIT)
Copyright (c) 2016 Damian Gryski
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+1
View File
@@ -0,0 +1 @@
This is a partial copy of github.com/dgryski/go-metro.
+87
View File
@@ -0,0 +1,87 @@
package metro
import "encoding/binary"
func Hash64(buffer []byte, seed uint64) uint64 {
const (
k0 = 0xD6D018F5
k1 = 0xA2AA033B
k2 = 0x62992FC1
k3 = 0x30BC5B29
)
ptr := buffer
hash := (seed + k2) * k0
if len(ptr) >= 32 {
v := [4]uint64{hash, hash, hash, hash}
for len(ptr) >= 32 {
v[0] += binary.LittleEndian.Uint64(ptr[:8]) * k0
v[0] = rotate_right(v[0], 29) + v[2]
v[1] += binary.LittleEndian.Uint64(ptr[8:16]) * k1
v[1] = rotate_right(v[1], 29) + v[3]
v[2] += binary.LittleEndian.Uint64(ptr[16:24]) * k2
v[2] = rotate_right(v[2], 29) + v[0]
v[3] += binary.LittleEndian.Uint64(ptr[24:32]) * k3
v[3] = rotate_right(v[3], 29) + v[1]
ptr = ptr[32:]
}
v[2] ^= rotate_right(((v[0]+v[3])*k0)+v[1], 37) * k1
v[3] ^= rotate_right(((v[1]+v[2])*k1)+v[0], 37) * k0
v[0] ^= rotate_right(((v[0]+v[2])*k0)+v[3], 37) * k1
v[1] ^= rotate_right(((v[1]+v[3])*k1)+v[2], 37) * k0
hash += v[0] ^ v[1]
}
if len(ptr) >= 16 {
v0 := hash + (binary.LittleEndian.Uint64(ptr[:8]) * k2)
v0 = rotate_right(v0, 29) * k3
v1 := hash + (binary.LittleEndian.Uint64(ptr[8:16]) * k2)
v1 = rotate_right(v1, 29) * k3
v0 ^= rotate_right(v0*k0, 21) + v1
v1 ^= rotate_right(v1*k3, 21) + v0
hash += v1
ptr = ptr[16:]
}
if len(ptr) >= 8 {
hash += binary.LittleEndian.Uint64(ptr[:8]) * k3
ptr = ptr[8:]
hash ^= rotate_right(hash, 55) * k1
}
if len(ptr) >= 4 {
hash += uint64(binary.LittleEndian.Uint32(ptr[:4])) * k3
hash ^= rotate_right(hash, 26) * k1
ptr = ptr[4:]
}
if len(ptr) >= 2 {
hash += uint64(binary.LittleEndian.Uint16(ptr[:2])) * k3
ptr = ptr[2:]
hash ^= rotate_right(hash, 48) * k1
}
if len(ptr) >= 1 {
hash += uint64(ptr[0]) * k3
hash ^= rotate_right(hash, 37) * k1
}
hash ^= rotate_right(hash, 28)
hash *= k0
hash ^= rotate_right(hash, 29)
return hash
}
func Hash64Str(buffer string, seed uint64) uint64 {
return Hash64([]byte(buffer), seed)
}
func rotate_right(v uint64, k uint) uint64 {
return (v >> k) | (v << (64 - k))
}
+759
View File
@@ -0,0 +1,759 @@
// Package miniredis is a pure Go Redis test server, for use in Go unittests.
// There are no dependencies on system binaries, and every server you start
// will be empty.
//
// import "github.com/alicebob/miniredis/v2"
//
// Start a server with `s := miniredis.RunT(t)`, it'll be shutdown via a t.Cleanup().
// Or do everything manual: `s, err := miniredis.Run(); defer s.Close()`
//
// Point your Redis client to `s.Addr()` or `s.Host(), s.Port()`.
//
// Set keys directly via s.Set(...) and similar commands, or use a Redis client.
//
// For direct use you can select a Redis database with either `s.Select(12);
// s.Get("foo")` or `s.DB(12).Get("foo")`.
package miniredis
import (
"context"
"crypto/tls"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"time"
"github.com/alicebob/miniredis/v2/proto"
"github.com/alicebob/miniredis/v2/server"
)
var DumpMaxLineLen = 60
type hashKey map[string]string
type listKey []string
type setKey map[string]struct{}
// RedisDB holds a single (numbered) Redis database.
type RedisDB struct {
master *Miniredis // pointer to the lock in Miniredis
id int // db id
keys map[string]string // Master map of keys with their type
stringKeys map[string]string // GET/SET &c. keys
hashKeys map[string]hashKey // MGET/MSET &c. keys
listKeys map[string]listKey // LPUSH &c. keys
setKeys map[string]setKey // SADD &c. keys
hllKeys map[string]*hll // PFADD &c. keys
sortedsetKeys map[string]sortedSet // ZADD &c. keys
streamKeys map[string]*streamKey // XADD &c. keys
ttl map[string]time.Duration // effective TTL values
lru map[string]time.Time // last recently used ( read or written to )
keyVersion map[string]uint // used to watch values
}
// Miniredis is a Redis server implementation.
type Miniredis struct {
sync.Mutex
srv *server.Server
port int
passwords map[string]string // username password
dbs map[int]*RedisDB
selectedDB int // DB id used in the direct Get(), Set() &c.
scripts map[string]string // sha1 -> lua src
signal *sync.Cond
now time.Time // time.Now() if not set.
subscribers map[*Subscriber]struct{}
rand *rand.Rand
Ctx context.Context
CtxCancel context.CancelFunc
}
type txCmd func(*server.Peer, *connCtx)
// database id + key combo
type dbKey struct {
db int
key string
}
// connCtx has all state for a single connection.
// (this struct was named before context.Context existed)
type connCtx struct {
selectedDB int // selected DB
authenticated bool // auth enabled and a valid AUTH seen
transaction []txCmd // transaction callbacks. Or nil.
dirtyTransaction bool // any error during QUEUEing
watch map[dbKey]uint // WATCHed keys
subscriber *Subscriber // client is in PUBSUB mode if not nil
nested bool // this is called via Lua
nestedSHA string // set to the SHA of the nesting function
}
// NewMiniRedis makes a new, non-started, Miniredis object.
func NewMiniRedis() *Miniredis {
m := Miniredis{
dbs: map[int]*RedisDB{},
scripts: map[string]string{},
subscribers: map[*Subscriber]struct{}{},
}
m.Ctx, m.CtxCancel = context.WithCancel(context.Background())
m.signal = sync.NewCond(&m)
return &m
}
func newRedisDB(id int, m *Miniredis) RedisDB {
return RedisDB{
id: id,
master: m,
keys: map[string]string{},
lru: map[string]time.Time{},
stringKeys: map[string]string{},
hashKeys: map[string]hashKey{},
listKeys: map[string]listKey{},
setKeys: map[string]setKey{},
hllKeys: map[string]*hll{},
sortedsetKeys: map[string]sortedSet{},
streamKeys: map[string]*streamKey{},
ttl: map[string]time.Duration{},
keyVersion: map[string]uint{},
}
}
// Run creates and Start()s a Miniredis.
func Run() (*Miniredis, error) {
m := NewMiniRedis()
return m, m.Start()
}
// Run creates and Start()s a Miniredis, TLS version.
func RunTLS(cfg *tls.Config) (*Miniredis, error) {
m := NewMiniRedis()
return m, m.StartTLS(cfg)
}
// Tester is a minimal version of a testing.T
type Tester interface {
Fatalf(string, ...interface{})
Cleanup(func())
Logf(format string, args ...interface{})
}
// RunT start a new miniredis, pass it a testing.T. It also registers the cleanup after your test is done.
func RunT(t Tester) *Miniredis {
m := NewMiniRedis()
if err := m.Start(); err != nil {
t.Fatalf("could not start miniredis: %s", err)
// not reached
}
t.Cleanup(m.Close)
return m
}
func runWithClient(t Tester) (*Miniredis, *proto.Client) {
m := RunT(t)
c, err := proto.Dial(m.Addr())
if err != nil {
t.Fatalf("could not connect to miniredis: %s", err)
}
t.Cleanup(func() {
if err = c.Close(); err != nil {
t.Logf("error closing connection to miniredis: %s", err)
}
})
return m, c
}
// Start starts a server. It listens on a random port on localhost. See also
// Addr().
func (m *Miniredis) Start() error {
s, err := server.NewServer(fmt.Sprintf("127.0.0.1:%d", m.port))
if err != nil {
return err
}
return m.start(s)
}
// Start starts a server, TLS version.
func (m *Miniredis) StartTLS(cfg *tls.Config) error {
s, err := server.NewServerTLS(fmt.Sprintf("127.0.0.1:%d", m.port), cfg)
if err != nil {
return err
}
return m.start(s)
}
// StartAddr runs miniredis with a given addr. Examples: "127.0.0.1:6379",
// ":6379", or "127.0.0.1:0"
func (m *Miniredis) StartAddr(addr string) error {
s, err := server.NewServer(addr)
if err != nil {
return err
}
return m.start(s)
}
// StartAddrTLS runs miniredis with a given addr, TLS version.
func (m *Miniredis) StartAddrTLS(addr string, cfg *tls.Config) error {
s, err := server.NewServerTLS(addr, cfg)
if err != nil {
return err
}
return m.start(s)
}
func (m *Miniredis) start(s *server.Server) error {
m.Lock()
defer m.Unlock()
m.srv = s
m.port = s.Addr().Port
commandsConnection(m)
commandsGeneric(m)
commandsServer(m)
commandsString(m)
commandsHash(m)
commandsList(m)
commandsPubsub(m)
commandsSet(m)
commandsSortedSet(m)
commandsStream(m)
commandsTransaction(m)
commandsScripting(m)
commandsGeo(m)
commandsCluster(m)
commandsHll(m)
commandsClient(m)
commandsObject(m)
return nil
}
// Restart restarts a Close()d server on the same port. Values will be
// preserved.
func (m *Miniredis) Restart() error {
return m.Start()
}
// Close shuts down a Miniredis.
func (m *Miniredis) Close() {
m.Lock()
if m.srv == nil {
m.Unlock()
return
}
srv := m.srv
m.srv = nil
m.CtxCancel()
m.Unlock()
// the OnDisconnect callbacks can lock m, so run Close() outside the lock.
srv.Close()
}
// RequireAuth makes every connection need to AUTH first. This is the old 'AUTH [password] command.
// Remove it by setting an empty string.
func (m *Miniredis) RequireAuth(pw string) {
m.RequireUserAuth("default", pw)
}
// Add a username/password, for use with 'AUTH [username] [password]'.
// There are currently no access controls for commands implemented.
// Disable access for the user with an empty password.
func (m *Miniredis) RequireUserAuth(username, pw string) {
m.Lock()
defer m.Unlock()
if m.passwords == nil {
m.passwords = map[string]string{}
}
if pw == "" {
delete(m.passwords, username)
return
}
m.passwords[username] = pw
}
// DB returns a DB by ID.
func (m *Miniredis) DB(i int) *RedisDB {
m.Lock()
defer m.Unlock()
return m.db(i)
}
// get DB. No locks!
func (m *Miniredis) db(i int) *RedisDB {
if db, ok := m.dbs[i]; ok {
return db
}
db := newRedisDB(i, m) // main miniredis has our mutex.
m.dbs[i] = &db
return &db
}
// SwapDB swaps DBs by IDs.
func (m *Miniredis) SwapDB(i, j int) bool {
m.Lock()
defer m.Unlock()
return m.swapDB(i, j)
}
// swap DB. No locks!
func (m *Miniredis) swapDB(i, j int) bool {
db1 := m.db(i)
db2 := m.db(j)
db1.id = j
db2.id = i
m.dbs[i] = db2
m.dbs[j] = db1
return true
}
// Addr returns '127.0.0.1:12345'. Can be given to a Dial(). See also Host()
// and Port(), which return the same things.
func (m *Miniredis) Addr() string {
m.Lock()
defer m.Unlock()
return m.srv.Addr().String()
}
// Host returns the host part of Addr().
func (m *Miniredis) Host() string {
m.Lock()
defer m.Unlock()
return m.srv.Addr().IP.String()
}
// Port returns the (random) port part of Addr().
func (m *Miniredis) Port() string {
m.Lock()
defer m.Unlock()
return strconv.Itoa(m.srv.Addr().Port)
}
// CommandCount returns the number of processed commands.
func (m *Miniredis) CommandCount() int {
m.Lock()
defer m.Unlock()
return int(m.srv.TotalCommands())
}
// CurrentConnectionCount returns the number of currently connected clients.
func (m *Miniredis) CurrentConnectionCount() int {
m.Lock()
defer m.Unlock()
return m.srv.ClientsLen()
}
// TotalConnectionCount returns the number of client connections since server start.
func (m *Miniredis) TotalConnectionCount() int {
m.Lock()
defer m.Unlock()
return int(m.srv.TotalConnections())
}
// FastForward decreases all TTLs by the given duration. All TTLs <= 0 will be
// expired.
func (m *Miniredis) FastForward(duration time.Duration) {
m.Lock()
defer m.Unlock()
for _, db := range m.dbs {
db.fastForward(duration)
}
}
// Server returns the underlying server to allow custom commands to be implemented
func (m *Miniredis) Server() *server.Server {
return m.srv
}
// Dump returns a text version of the selected DB, usable for debugging.
//
// Dump limits the maximum length of each key:value to "DumpMaxLineLen" characters.
// To increase that, call something like:
//
// miniredis.DumpMaxLineLen = 1024
// mr, _ = miniredis.Run()
// mr.Dump()
func (m *Miniredis) Dump() string {
m.Lock()
defer m.Unlock()
var (
maxLen = DumpMaxLineLen
indent = " "
db = m.db(m.selectedDB)
r = ""
v = func(s string) string {
suffix := ""
if len(s) > maxLen {
suffix = fmt.Sprintf("...(%d)", len(s))
s = s[:maxLen-len(suffix)]
}
return fmt.Sprintf("%q%s", s, suffix)
}
)
for _, k := range db.allKeys() {
r += fmt.Sprintf("- %s\n", k)
t := db.t(k)
switch t {
case keyTypeString:
r += fmt.Sprintf("%s%s\n", indent, v(db.stringKeys[k]))
case keyTypeHash:
for _, hk := range db.hashFields(k) {
r += fmt.Sprintf("%s%s: %s\n", indent, hk, v(db.hashGet(k, hk)))
}
case keyTypeList:
for _, lk := range db.listKeys[k] {
r += fmt.Sprintf("%s%s\n", indent, v(lk))
}
case keyTypeSet:
for _, mk := range db.setMembers(k) {
r += fmt.Sprintf("%s%s\n", indent, v(mk))
}
case keyTypeSortedSet:
for _, el := range db.ssetElements(k) {
r += fmt.Sprintf("%s%f: %s\n", indent, el.score, v(el.member))
}
case keyTypeStream:
for _, entry := range db.streamKeys[k].entries {
r += fmt.Sprintf("%s%s\n", indent, entry.ID)
ev := entry.Values
for i := 0; i < len(ev)/2; i++ {
r += fmt.Sprintf("%s%s%s: %s\n", indent, indent, v(ev[2*i]), v(ev[2*i+1]))
}
}
case keyTypeHll:
for _, entry := range db.hllKeys {
r += fmt.Sprintf("%s%s\n", indent, v(string(entry.Bytes())))
}
default:
r += fmt.Sprintf("%s(a %s, fixme!)\n", indent, t)
}
}
return r
}
// SetTime sets the time against which EXPIREAT values are compared, and the
// time used in stream entry IDs. Will use time.Now() if this is not set.
func (m *Miniredis) SetTime(t time.Time) {
m.Lock()
defer m.Unlock()
m.now = t
}
// make every command return this message. For example:
//
// LOADING Redis is loading the dataset in memory
// MASTERDOWN Link with MASTER is down and replica-serve-stale-data is set to 'no'.
//
// Clear it with an empty string. Don't add newlines.
func (m *Miniredis) SetError(msg string) {
cb := server.Hook(nil)
if msg != "" {
cb = func(c *server.Peer, cmd string, args ...string) bool {
c.WriteError(msg)
return true
}
}
m.srv.SetPreHook(cb)
}
// isValidCMD returns true if command is valid and can be executed.
func (m *Miniredis) isValidCMD(c *server.Peer, cmd string) bool {
if !m.handleAuth(c) {
return false
}
if m.checkPubsub(c, cmd) {
return false
}
return true
}
// handleAuth returns false if connection has no access. It sends the reply.
func (m *Miniredis) handleAuth(c *server.Peer) bool {
if getCtx(c).nested {
return true
}
m.Lock()
defer m.Unlock()
if len(m.passwords) == 0 {
return true
}
if !getCtx(c).authenticated {
c.WriteError("NOAUTH Authentication required.")
return false
}
return true
}
// handlePubsub sends an error to the user if the connection is in PUBSUB mode.
// It'll return true if it did.
func (m *Miniredis) checkPubsub(c *server.Peer, cmd string) bool {
if getCtx(c).nested {
return false
}
m.Lock()
defer m.Unlock()
ctx := getCtx(c)
if ctx.subscriber == nil {
return false
}
prefix := "ERR "
if strings.ToLower(cmd) == "exec" {
prefix = "EXECABORT Transaction discarded because of: "
}
c.WriteError(fmt.Sprintf(
"%sCan't execute '%s': only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are allowed in this context",
prefix,
strings.ToLower(cmd),
))
return true
}
func getCtx(c *server.Peer) *connCtx {
if c.Ctx == nil {
c.Ctx = &connCtx{}
}
return c.Ctx.(*connCtx)
}
func startTx(ctx *connCtx) {
ctx.transaction = []txCmd{}
ctx.dirtyTransaction = false
}
func stopTx(ctx *connCtx) {
ctx.transaction = nil
unwatch(ctx)
}
func inTx(ctx *connCtx) bool {
return ctx.transaction != nil
}
func addTxCmd(ctx *connCtx, cb txCmd) {
ctx.transaction = append(ctx.transaction, cb)
}
func watch(db *RedisDB, ctx *connCtx, key string) {
if ctx.watch == nil {
ctx.watch = map[dbKey]uint{}
}
ctx.watch[dbKey{db: db.id, key: key}] = db.keyVersion[key] // Can be 0.
}
func unwatch(ctx *connCtx) {
ctx.watch = nil
}
// setDirty can be called even when not in an tx. Is an no-op then.
func setDirty(c *server.Peer) {
if c.Ctx == nil {
// No transaction. Not relevant.
return
}
getCtx(c).dirtyTransaction = true
}
func (m *Miniredis) addSubscriber(s *Subscriber) {
m.subscribers[s] = struct{}{}
}
// closes and remove the subscriber.
func (m *Miniredis) removeSubscriber(s *Subscriber) {
_, ok := m.subscribers[s]
delete(m.subscribers, s)
if ok {
s.Close()
}
}
func (m *Miniredis) publish(c, msg string) int {
n := 0
for s := range m.subscribers {
n += s.Publish(c, msg)
}
return n
}
// enter 'subscribed state', or return the existing one.
func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber {
ctx := getCtx(c)
sub := ctx.subscriber
if sub != nil {
return sub
}
sub = newSubscriber()
m.addSubscriber(sub)
c.OnDisconnect(func() {
m.Lock()
m.removeSubscriber(sub)
m.Unlock()
})
ctx.subscriber = sub
go monitorPublish(c, sub.publish)
go monitorPpublish(c, sub.ppublish)
return sub
}
// whenever the p?sub count drops to 0 subscribed state should be stopped, and
// all redis commands are allowed again.
func endSubscriber(m *Miniredis, c *server.Peer) {
ctx := getCtx(c)
if sub := ctx.subscriber; sub != nil {
m.removeSubscriber(sub) // will Close() the sub
}
ctx.subscriber = nil
}
// Start a new pubsub subscriber. It can (un) subscribe to channels and
// patterns, and has a channel to get published messages. Close it with
// Close().
// Does not close itself when there are no subscriptions left.
func (m *Miniredis) NewSubscriber() *Subscriber {
sub := newSubscriber()
m.Lock()
m.addSubscriber(sub)
m.Unlock()
return sub
}
func (m *Miniredis) allSubscribers() []*Subscriber {
var subs []*Subscriber
for s := range m.subscribers {
subs = append(subs, s)
}
return subs
}
func (m *Miniredis) Seed(seed int) {
m.Lock()
defer m.Unlock()
// m.rand is not safe for concurrent use.
m.rand = rand.New(rand.NewSource(int64(seed)))
}
func (m *Miniredis) randIntn(n int) int {
if m.rand == nil {
return rand.Intn(n)
}
return m.rand.Intn(n)
}
// shuffle shuffles a list of strings. Kinda.
func (m *Miniredis) shuffle(l []string) {
for range l {
i := m.randIntn(len(l))
j := m.randIntn(len(l))
l[i], l[j] = l[j], l[i]
}
}
func (m *Miniredis) effectiveNow() time.Time {
if !m.now.IsZero() {
return m.now
}
return time.Now().UTC()
}
// convert a unixtimestamp to a duration, to use an absolute time as TTL.
// d can be either time.Second or time.Millisecond.
func (m *Miniredis) at(i int, d time.Duration) time.Duration {
var ts time.Time
switch d {
case time.Millisecond:
ts = time.Unix(int64(i/1000), 1000000*int64(i%1000))
case time.Second:
ts = time.Unix(int64(i), 0)
default:
panic("invalid time unit (d). Fixme!")
}
now := m.effectiveNow()
return ts.Sub(now)
}
// copy does not mind if dst already exists.
func (m *Miniredis) copy(
srcDB *RedisDB, src string,
destDB *RedisDB, dst string,
) error {
if !srcDB.exists(src) {
return ErrKeyNotFound
}
switch srcDB.t(src) {
case keyTypeString:
destDB.stringKeys[dst] = srcDB.stringKeys[src]
case keyTypeHash:
destDB.hashKeys[dst] = copyHashKey(srcDB.hashKeys[src])
case keyTypeList:
destDB.listKeys[dst] = copyListKey(srcDB.listKeys[src])
case keyTypeSet:
destDB.setKeys[dst] = copySetKey(srcDB.setKeys[src])
case keyTypeSortedSet:
destDB.sortedsetKeys[dst] = copySortedSet(srcDB.sortedsetKeys[src])
case keyTypeStream:
destDB.streamKeys[dst] = srcDB.streamKeys[src].copy()
case keyTypeHll:
destDB.hllKeys[dst] = srcDB.hllKeys[src].copy()
default:
panic("missing case")
}
destDB.keys[dst] = srcDB.keys[src]
destDB.incr(dst)
if v, ok := srcDB.ttl[src]; ok {
destDB.ttl[dst] = v
}
return nil
}
func copyHashKey(orig hashKey) hashKey {
cpy := hashKey{}
for k, v := range orig {
cpy[k] = v
}
return cpy
}
func copyListKey(orig listKey) listKey {
cpy := make(listKey, len(orig))
copy(cpy, orig)
return cpy
}
func copySetKey(orig setKey) setKey {
cpy := setKey{}
for k, v := range orig {
cpy[k] = v
}
return cpy
}
func copySortedSet(orig sortedSet) sortedSet {
cpy := sortedSet{}
for k, v := range orig {
cpy[k] = v
}
return cpy
}
+60
View File
@@ -0,0 +1,60 @@
package miniredis
import (
"errors"
"math"
"strconv"
"time"
"github.com/alicebob/miniredis/v2/server"
)
// optInt parses an int option in a command.
// Writes "invalid integer" error to c if it's not a valid integer. Returns
// whether or not things were okay.
func optInt(c *server.Peer, src string, dest *int) bool {
return optIntErr(c, src, dest, msgInvalidInt)
}
func optIntErr(c *server.Peer, src string, dest *int, errMsg string) bool {
n, err := strconv.Atoi(src)
if err != nil {
setDirty(c)
c.WriteError(errMsg)
return false
}
*dest = n
return true
}
// optIntSimple sets dest or returns an error
func optIntSimple(src string, dest *int) error {
n, err := strconv.Atoi(src)
if err != nil {
return errors.New(msgInvalidInt)
}
*dest = n
return nil
}
func optDuration(c *server.Peer, src string, dest *time.Duration) bool {
n, err := strconv.ParseFloat(src, 64)
if err != nil {
setDirty(c)
c.WriteError(msgInvalidTimeout)
return false
}
if n < 0 {
setDirty(c)
c.WriteError(msgTimeoutNegative)
return false
}
if math.IsInf(n, 0) {
setDirty(c)
c.WriteError(msgTimeoutIsOutOfRange)
return false
}
*dest = time.Duration(n*1_000_000) * time.Microsecond
return true
}

Some files were not shown because too many files have changed in this diff Show More