mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Add redis support for distributed caching
This commit is contained in:
+295
@@ -31,6 +31,7 @@ summary: |
|
||||
- Flexible configuration with multiple deployment scenarios
|
||||
- Memory-efficient operation with automatic cleanup
|
||||
- Extensive logging and debugging capabilities
|
||||
- Redis cache support for multi-replica deployments with automatic failover
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
@@ -137,6 +138,42 @@ testData:
|
||||
X-Custom-Header: "production"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Example with Redis cache for multi-replica deployments
|
||||
testDataWithRedis:
|
||||
# Required OIDC parameters (same as standard configuration)
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
sessionEncryptionKey: your-64-character-encryption-key-at-least-32-bytes
|
||||
|
||||
# Standard optional parameters
|
||||
logLevel: info
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
|
||||
# Redis cache configuration for multi-replica support
|
||||
redis:
|
||||
enabled: true # Enable Redis caching
|
||||
address: "redis:6379" # Redis server address
|
||||
password: "redis-password" # Redis authentication password
|
||||
db: 0 # Redis database number (0-15)
|
||||
keyPrefix: "traefikoidc:" # Prefix for all Redis keys
|
||||
cacheMode: "hybrid" # Cache mode: redis, hybrid, or memory
|
||||
poolSize: 20 # Maximum number of connections
|
||||
connectTimeout: 5 # Connection timeout in seconds
|
||||
readTimeout: 3 # Read operation timeout
|
||||
writeTimeout: 3 # Write operation timeout
|
||||
enableTLS: false # Use TLS for Redis connection
|
||||
tlsSkipVerify: false # Skip TLS certificate verification
|
||||
hybridL1Size: 500 # L1 cache size for hybrid mode
|
||||
hybridL1MemoryMB: 10 # L1 memory limit for hybrid mode
|
||||
enableCircuitBreaker: true # Enable circuit breaker
|
||||
circuitBreakerThreshold: 5 # Failures before opening circuit
|
||||
circuitBreakerTimeout: 60 # Timeout before retry (seconds)
|
||||
enableHealthCheck: true # Enable periodic health checks
|
||||
healthCheckInterval: 30 # Health check interval (seconds)
|
||||
|
||||
# --- Common Configuration Examples ---
|
||||
#
|
||||
# 🔒 HIGH-SECURITY CONFIGURATION
|
||||
@@ -1138,3 +1175,261 @@ configuration:
|
||||
|
||||
Prevents your resources from being embedded on other sites.
|
||||
required: false
|
||||
|
||||
redis:
|
||||
type: object
|
||||
description: |
|
||||
Optional Redis cache configuration for multi-replica deployments.
|
||||
|
||||
When running multiple Traefik instances, Redis provides shared caching to:
|
||||
- Prevent JTI replay detection false positives across replicas
|
||||
- Share token verification results between instances
|
||||
- Maintain consistent session state across the cluster
|
||||
- Improve performance by reducing redundant OIDC provider calls
|
||||
|
||||
Features:
|
||||
- Automatic failover to memory-only mode when Redis is unavailable
|
||||
- Circuit breaker pattern for resilience against Redis failures
|
||||
- Health checking with automatic recovery
|
||||
- Multiple cache modes: redis-only, hybrid (L1 memory + L2 Redis), memory-only
|
||||
- Configurable timeouts and connection pooling
|
||||
- TLS support for secure Redis connections
|
||||
|
||||
The middleware gracefully handles Redis failures by falling back to in-memory
|
||||
caching, ensuring your authentication flow continues even during Redis outages.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableCircuitBreaker: true
|
||||
```
|
||||
required: false
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable Redis caching for distributed session and token management.
|
||||
When enabled, the middleware will attempt to connect to Redis and use it
|
||||
for shared state across multiple Traefik instances.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
address:
|
||||
type: string
|
||||
description: |
|
||||
Redis server address in host:port format.
|
||||
|
||||
Examples:
|
||||
- "redis:6379" (Docker/Kubernetes service)
|
||||
- "localhost:6379" (local Redis)
|
||||
- "redis.example.com:6380" (custom host/port)
|
||||
- "redis-cluster.default.svc.cluster.local:6379" (Kubernetes)
|
||||
|
||||
Required when Redis is enabled.
|
||||
required: false
|
||||
|
||||
password:
|
||||
type: string
|
||||
description: |
|
||||
Password for Redis authentication.
|
||||
Leave empty if Redis doesn't require authentication.
|
||||
|
||||
For Kubernetes deployments, you can use secret references:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
|
||||
Default: "" (no authentication)
|
||||
required: false
|
||||
|
||||
db:
|
||||
type: integer
|
||||
description: |
|
||||
Redis database number to use (0-15).
|
||||
Different databases can be used to isolate data between environments.
|
||||
|
||||
Default: 0
|
||||
required: false
|
||||
|
||||
keyPrefix:
|
||||
type: string
|
||||
description: |
|
||||
Prefix for all Redis keys created by this middleware.
|
||||
Useful for:
|
||||
- Avoiding key collisions with other applications
|
||||
- Identifying keys for monitoring/debugging
|
||||
- Supporting multiple environments in the same Redis instance
|
||||
|
||||
Default: "traefikoidc:"
|
||||
required: false
|
||||
|
||||
cacheMode:
|
||||
type: string
|
||||
description: |
|
||||
Determines the caching strategy:
|
||||
|
||||
- "redis": Redis-only caching. All cache operations go directly to Redis.
|
||||
Best for: Consistent state across all replicas, minimal memory usage.
|
||||
|
||||
- "hybrid": Two-tier caching with in-memory L1 and Redis L2.
|
||||
Best for: High performance with shared state, reduced Redis load.
|
||||
L1 provides fast local cache, L2 provides shared state.
|
||||
|
||||
- "memory": Memory-only caching (Redis disabled even if configured).
|
||||
Best for: Single instance deployments, development/testing.
|
||||
|
||||
Default: "redis" (when Redis is enabled)
|
||||
required: false
|
||||
enum:
|
||||
- redis
|
||||
- hybrid
|
||||
- memory
|
||||
|
||||
poolSize:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of socket connections to Redis.
|
||||
Higher values allow more concurrent operations but consume more resources.
|
||||
|
||||
Recommendations:
|
||||
- Small deployments: 10-20
|
||||
- Medium deployments: 20-50
|
||||
- Large deployments: 50-100
|
||||
|
||||
Default: 10
|
||||
required: false
|
||||
|
||||
connectTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for establishing new connections to Redis.
|
||||
Should be higher than network latency but low enough to fail fast.
|
||||
|
||||
Default: 5 seconds
|
||||
required: false
|
||||
|
||||
readTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis read operations.
|
||||
Includes the time to send the command, wait for Redis to process it,
|
||||
and receive the response.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
writeTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis write operations.
|
||||
Should account for network latency and Redis persistence settings.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
enableTLS:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable TLS encryption for Redis connections.
|
||||
Required when connecting to Redis instances that enforce TLS,
|
||||
such as AWS ElastiCache with encryption in transit.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
tlsSkipVerify:
|
||||
type: boolean
|
||||
description: |
|
||||
Skip TLS certificate verification for Redis connections.
|
||||
|
||||
⚠️ WARNING: Only use in development environments.
|
||||
This option bypasses certificate validation and should never be used
|
||||
in production as it's vulnerable to man-in-the-middle attacks.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
hybridL1Size:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of items in the L1 (in-memory) cache for hybrid mode.
|
||||
Controls how many cache entries are kept in local memory before eviction.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 500
|
||||
required: false
|
||||
|
||||
hybridL1MemoryMB:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum memory in megabytes for L1 cache in hybrid mode.
|
||||
The cache will start evicting items when this limit is approached.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 10 MB
|
||||
required: false
|
||||
|
||||
enableCircuitBreaker:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable circuit breaker pattern for Redis connection failures.
|
||||
|
||||
When enabled, the middleware will:
|
||||
1. Track Redis operation failures
|
||||
2. Open the circuit after threshold failures (stop trying Redis)
|
||||
3. Fall back to in-memory caching
|
||||
4. Periodically attempt to reconnect (half-open state)
|
||||
5. Resume Redis operations when connection recovers
|
||||
|
||||
This prevents cascading failures and improves resilience.
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
circuitBreakerThreshold:
|
||||
type: integer
|
||||
description: |
|
||||
Number of consecutive Redis failures before opening the circuit.
|
||||
Lower values make the system more sensitive to Redis issues,
|
||||
higher values tolerate more failures before switching to fallback.
|
||||
|
||||
Default: 5
|
||||
required: false
|
||||
|
||||
circuitBreakerTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Time in seconds to wait before attempting to close the circuit.
|
||||
After this timeout, the circuit breaker will allow one test request
|
||||
to Redis. If successful, normal operations resume.
|
||||
|
||||
Default: 60 seconds
|
||||
required: false
|
||||
|
||||
enableHealthCheck:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable periodic health checks for Redis connection.
|
||||
|
||||
Health checks:
|
||||
- Run in the background at regular intervals
|
||||
- Detect Redis availability without affecting request processing
|
||||
- Automatically reconnect when Redis becomes available
|
||||
- Update circuit breaker state based on health status
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
healthCheckInterval:
|
||||
type: integer
|
||||
description: |
|
||||
Interval in seconds between Redis health checks.
|
||||
Lower values detect issues faster but increase Redis load.
|
||||
Higher values reduce overhead but delay failure detection.
|
||||
|
||||
Default: 30 seconds
|
||||
required: false
|
||||
|
||||
@@ -133,6 +133,7 @@ The middleware supports the following configuration options:
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
>
|
||||
@@ -520,12 +521,14 @@ When running multiple Traefik replicas with the OIDC plugin, you may encounter f
|
||||
- Request → Replica B → JTI NOT in Replica B's cache ✓
|
||||
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
|
||||
|
||||
**Solution**: Disable replay detection for distributed deployments:
|
||||
**Solution 1 (Simple)**: Disable replay detection for distributed deployments:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
|
||||
```
|
||||
|
||||
**Solution 2 (Recommended)**: Use Redis cache backend for shared state (see [Redis Cache](#redis-cache-optional) section)
|
||||
|
||||
**Security Note**: When `disableReplayDetection: true`:
|
||||
- ✅ Token signatures still validated
|
||||
- ✅ Expiration still checked
|
||||
@@ -547,10 +550,158 @@ spec:
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
disableReplayDetection: true # Required for multi-replica deployments
|
||||
disableReplayDetection: true # Required for multi-replica deployments without Redis
|
||||
```
|
||||
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, set to `true` and consider implementing a shared cache backend (Redis/Memcached) if replay detection is required.
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, use the Redis cache backend for proper replay detection across all instances.
|
||||
|
||||
## Redis Cache (Optional)
|
||||
|
||||
The plugin supports optional Redis caching for multi-replica deployments. This solves issues with JTI replay detection and session management when running multiple Traefik instances behind a load balancer.
|
||||
|
||||
### Why Use Redis Cache?
|
||||
|
||||
When running multiple Traefik replicas, each instance maintains its own in-memory cache for:
|
||||
- JTI (JWT Token ID) replay detection
|
||||
- Session data
|
||||
- Token metadata
|
||||
|
||||
Without a shared cache, you may experience:
|
||||
- False positive replay detection errors
|
||||
- Session inconsistencies between replicas
|
||||
- Users needing to re-authenticate when hitting different instances
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Redis is configured through Traefik's dynamic configuration (YAML, labels, etc.):
|
||||
|
||||
```yaml
|
||||
# Enable Redis cache in your middleware configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "localhost:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
```
|
||||
|
||||
### Configuration Priority
|
||||
|
||||
The plugin uses the following priority for Redis configuration:
|
||||
|
||||
1. **Traefik Dynamic Configuration** (PRIMARY) - Configure via YAML files or Docker/Kubernetes labels
|
||||
2. **Environment Variables** (FALLBACK) - Used only when not set in Traefik config
|
||||
|
||||
This approach allows you to manage all settings through Traefik's configuration system while maintaining backward compatibility with environment variables.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `enabled` | Enable Redis caching | `false` | `true` |
|
||||
| `address` | Redis server address | - | `redis:6379` |
|
||||
| `password` | Redis password | - | `secret` |
|
||||
| `db` | Database number | `0` | `1` |
|
||||
| `keyPrefix` | Key prefix for namespacing | `traefikoidc:` | `myapp:` |
|
||||
| `cacheMode` | Cache mode: `redis`, `hybrid`, `memory` | `redis` | `hybrid` |
|
||||
| `poolSize` | Connection pool size | `10` | `20` |
|
||||
| `connectTimeout` | Connection timeout (seconds) | `5` | `10` |
|
||||
| `readTimeout` | Read timeout (seconds) | `3` | `5` |
|
||||
| `writeTimeout` | Write timeout (seconds) | `3` | `5` |
|
||||
| `enableTLS` | Enable TLS | `false` | `true` |
|
||||
| `tlsSkipVerify` | Skip TLS verification | `false` | `true` |
|
||||
| `enableCircuitBreaker` | Circuit breaker for failures | `true` | `true` |
|
||||
| `circuitBreakerThreshold` | Failures before circuit opens | `5` | `10` |
|
||||
| `circuitBreakerTimeout` | Circuit reset timeout (seconds) | `60` | `30` |
|
||||
| `enableHealthCheck` | Periodic health checks | `true` | `true` |
|
||||
| `healthCheckInterval` | Health check interval (seconds) | `30` | `60` |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables can be used as fallback:
|
||||
|
||||
- `REDIS_ENABLED` - Enable Redis cache
|
||||
- `REDIS_ADDRESS` - Redis server address
|
||||
- `REDIS_PASSWORD` - Redis password
|
||||
- `REDIS_DB` - Database number
|
||||
- `REDIS_KEY_PREFIX` - Key prefix
|
||||
- `REDIS_CACHE_MODE` - Cache mode
|
||||
- `REDIS_POOL_SIZE` - Connection pool size
|
||||
- `REDIS_CONNECT_TIMEOUT` - Connection timeout
|
||||
- `REDIS_READ_TIMEOUT` - Read timeout
|
||||
- `REDIS_WRITE_TIMEOUT` - Write timeout
|
||||
- `REDIS_ENABLE_TLS` - Enable TLS
|
||||
- `REDIS_TLS_SKIP_VERIFY` - Skip TLS verification
|
||||
|
||||
### Cache Modes
|
||||
|
||||
The plugin supports three cache modes:
|
||||
|
||||
- **memory** (default): In-memory cache only, suitable for single-instance deployments
|
||||
- **redis**: Redis-only cache, all data stored in Redis
|
||||
- **hybrid**: Two-tier caching with local memory cache + Redis backend for optimal performance
|
||||
|
||||
### Example Configurations
|
||||
|
||||
#### Docker Compose with Redis
|
||||
|
||||
```yaml
|
||||
services:
|
||||
redis:
|
||||
image: redis:alpine
|
||||
command: redis-server --requirepass yourpassword
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
# ... rest of your Traefik configuration
|
||||
labels:
|
||||
# Configure the OIDC middleware with Redis
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# Redis configuration via labels
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=yourpassword"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
```
|
||||
|
||||
#### Kubernetes with Redis
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### Advanced Redis Configuration
|
||||
|
||||
See [Redis Cache Documentation](docs/REDIS_CACHE.md) for:
|
||||
- Detailed architecture overview
|
||||
- High availability setup with Redis Sentinel
|
||||
- Redis Cluster configuration
|
||||
- Performance tuning guidelines
|
||||
- Monitoring and observability
|
||||
- Troubleshooting guide
|
||||
- Migration from memory-only cache
|
||||
|
||||
## Usage Examples
|
||||
|
||||
|
||||
+28
-1
@@ -21,10 +21,37 @@ var (
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
if config != nil {
|
||||
logger = NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
redisConfig = config.Redis
|
||||
}
|
||||
|
||||
globalCacheManagerInstance = &CacheManager{
|
||||
manager: GetUniversalCacheManager(nil),
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
return globalCacheManagerInstance
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
// Package config provides configuration structures for the Traefik OIDC plugin.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedisMode represents the Redis deployment mode
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
// RedisModeStandalone represents a single Redis instance
|
||||
RedisModeStandalone RedisMode = "standalone"
|
||||
|
||||
// RedisModeCluster represents Redis cluster mode
|
||||
RedisModeCluster RedisMode = "cluster"
|
||||
|
||||
// RedisModeSentinel represents Redis sentinel mode
|
||||
RedisModeSentinel RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
// RedisConfig holds Redis cache backend configuration
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis backend should be used
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// Mode specifies the Redis deployment mode
|
||||
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
|
||||
|
||||
// === Standalone Configuration ===
|
||||
// Addr is the Redis server address (host:port)
|
||||
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
|
||||
|
||||
// Password for Redis authentication
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the database number (0-15)
|
||||
DB int `json:"db,omitempty" yaml:"db,omitempty"`
|
||||
|
||||
// === Cluster Configuration ===
|
||||
// ClusterAddrs is the list of cluster node addresses
|
||||
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
|
||||
|
||||
// === Sentinel Configuration ===
|
||||
// MasterName is the name of the master instance
|
||||
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
|
||||
|
||||
// SentinelAddrs is the list of sentinel addresses
|
||||
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
|
||||
|
||||
// SentinelPassword is the password for sentinel authentication
|
||||
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
|
||||
|
||||
// === Connection Pool Settings ===
|
||||
// PoolSize is the maximum number of socket connections
|
||||
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections
|
||||
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up
|
||||
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
|
||||
|
||||
// === Timeouts ===
|
||||
// DialTimeout is the timeout for establishing new connections
|
||||
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
|
||||
|
||||
// ReadTimeout is the timeout for socket reads
|
||||
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
|
||||
|
||||
// WriteTimeout is the timeout for socket writes
|
||||
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
|
||||
|
||||
// PoolTimeout is the timeout for connection pool
|
||||
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
|
||||
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
|
||||
|
||||
// ConnMaxLifetime is the maximum lifetime of a connection
|
||||
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
|
||||
|
||||
// === Key Management ===
|
||||
// KeyPrefix is the prefix for all Redis keys
|
||||
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
|
||||
|
||||
// === TLS Configuration ===
|
||||
// TLSEnabled enables TLS for Redis connections
|
||||
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
|
||||
|
||||
// TLSInsecureSkipVerify skips TLS certificate verification
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
|
||||
|
||||
// === Resilience Settings ===
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis operations
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
|
||||
|
||||
// CircuitBreakerMaxFailures is the number of failures before opening circuit
|
||||
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
|
||||
|
||||
// CircuitBreakerTimeout is how long the circuit stays open
|
||||
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks
|
||||
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
|
||||
|
||||
// HealthCheckInterval is how often to check Redis health
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns default Redis configuration
|
||||
func DefaultRedisConfig() *RedisConfig {
|
||||
return &RedisConfig{
|
||||
Enabled: false,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 2,
|
||||
MaxRetries: 3,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
KeyPrefix: "traefikoidc:",
|
||||
TLSEnabled: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
EnableCircuitBreaker: true,
|
||||
CircuitBreakerMaxFailures: 5,
|
||||
CircuitBreakerTimeout: 30 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Redis configuration from environment variables
|
||||
func (c *RedisConfig) LoadFromEnv() {
|
||||
// Enable Redis if environment variable is set
|
||||
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
|
||||
c.Enabled = strings.ToLower(enabled) == "true"
|
||||
}
|
||||
|
||||
// Mode
|
||||
if mode := os.Getenv("REDIS_MODE"); mode != "" {
|
||||
c.Mode = RedisMode(strings.ToLower(mode))
|
||||
}
|
||||
|
||||
// Standalone configuration
|
||||
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
|
||||
c.Addr = addr
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
c.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster configuration
|
||||
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
|
||||
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
|
||||
for i := range c.ClusterAddrs {
|
||||
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel configuration
|
||||
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
|
||||
c.MasterName = masterName
|
||||
}
|
||||
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
|
||||
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
|
||||
for i := range c.SentinelAddrs {
|
||||
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
|
||||
}
|
||||
}
|
||||
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
|
||||
c.SentinelPassword = sentinelPassword
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if size, err := strconv.Atoi(poolSize); err == nil {
|
||||
c.PoolSize = size
|
||||
}
|
||||
}
|
||||
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
|
||||
if conns, err := strconv.Atoi(minIdleConns); err == nil {
|
||||
c.MinIdleConns = conns
|
||||
}
|
||||
}
|
||||
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
|
||||
if retries, err := strconv.Atoi(maxRetries); err == nil {
|
||||
c.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
// Timeouts
|
||||
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
|
||||
c.DialTimeout = timeout
|
||||
}
|
||||
}
|
||||
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(readTimeout); err == nil {
|
||||
c.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
|
||||
c.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Key prefix
|
||||
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
|
||||
c.KeyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
|
||||
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
|
||||
}
|
||||
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
|
||||
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
|
||||
}
|
||||
|
||||
// Resilience settings
|
||||
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
|
||||
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
|
||||
}
|
||||
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
|
||||
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
|
||||
c.CircuitBreakerMaxFailures = failures
|
||||
}
|
||||
}
|
||||
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
|
||||
c.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
|
||||
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
|
||||
}
|
||||
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
|
||||
if interval, err := time.ParseDuration(hcInterval); err == nil {
|
||||
c.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *RedisConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case RedisModeStandalone:
|
||||
if c.Addr == "" {
|
||||
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
|
||||
}
|
||||
case RedisModeCluster:
|
||||
if len(c.ClusterAddrs) == 0 {
|
||||
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
|
||||
}
|
||||
case RedisModeSentinel:
|
||||
if c.MasterName == "" {
|
||||
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
|
||||
}
|
||||
if len(c.SentinelAddrs) == 0 {
|
||||
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
|
||||
}
|
||||
default:
|
||||
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ConfigError) Error() string {
|
||||
return "redis config error: " + e.Field + ": " + e.Message
|
||||
}
|
||||
+1110
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,413 @@
|
||||
# Redis Cache Backend Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Directory Organization
|
||||
|
||||
```
|
||||
internal/cache/
|
||||
├── backend/
|
||||
│ ├── interface.go # CacheBackend interface definition
|
||||
│ ├── interface_test.go # Contract tests for all backends
|
||||
│ ├── memory.go # In-memory backend implementation
|
||||
│ ├── memory_test.go # Memory backend unit tests
|
||||
│ ├── redis.go # Redis backend implementation
|
||||
│ ├── redis_test.go # Redis backend unit tests
|
||||
│ ├── errors.go # Error definitions
|
||||
│ └── test_helpers_test.go # Test infrastructure and helpers
|
||||
│
|
||||
└── resilience/
|
||||
├── circuit_breaker.go # Circuit breaker implementation
|
||||
├── circuit_breaker_test.go # Circuit breaker tests
|
||||
├── health_check.go # Health checker implementation
|
||||
└── health_check_test.go # Health check tests
|
||||
|
||||
redis_integration_test.go # End-to-end integration tests
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Interface Contract Tests (`interface_test.go`)
|
||||
|
||||
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
|
||||
|
||||
**Test Cases:**
|
||||
- `TestCacheBackendContract` - Runs all contract tests against each backend type
|
||||
- `testBasicSetGet` - Verifies basic set/get operations
|
||||
- `testGetNonExistent` - Tests behavior for non-existent keys
|
||||
- `testUpdateExisting` - Validates updating existing keys
|
||||
- `testDelete` - Tests delete operations
|
||||
- `testDeleteNonExistent` - Delete non-existent keys
|
||||
- `testExists` - Key existence checking
|
||||
- `testTTLExpiration` - TTL and expiration behavior
|
||||
- `testClear` - Clear all keys operation
|
||||
- `testPing` - Health check functionality
|
||||
- `testStats` - Statistics tracking
|
||||
- `testConcurrentAccess` - Thread safety with 10+ goroutines
|
||||
- `testLargeValues` - Handling of 1MB+ values
|
||||
- `testEmptyValues` - Empty byte array handling
|
||||
- `testSpecialCharactersInKeys` - Special characters in key names
|
||||
|
||||
**Coverage:** ~95% of interface methods
|
||||
|
||||
### 2. Memory Backend Tests (`memory_test.go`)
|
||||
|
||||
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (6 tests)
|
||||
- `TestMemoryBackend_BasicOperations` - CRUD operations
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- DeleteNonExistent
|
||||
- Exists
|
||||
- Clear
|
||||
|
||||
#### TTL and Expiration (3 tests)
|
||||
- `TestMemoryBackend_TTLExpiration`
|
||||
- ShortTTL (100ms)
|
||||
- TTLDecrement over time
|
||||
- CleanupExpiredItems
|
||||
|
||||
#### LRU Eviction (2 tests)
|
||||
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
|
||||
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
|
||||
|
||||
#### Edge Cases (6 tests)
|
||||
- `TestMemoryBackend_UpdateExisting` - Overwriting values
|
||||
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
|
||||
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
|
||||
- `TestMemoryBackend_LargeValues` - 1MB values
|
||||
- `TestMemoryBackend_Close` - Proper cleanup
|
||||
- `TestMemoryBackend_Ping` - Health checks
|
||||
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
|
||||
|
||||
**Coverage:** ~92% of memory backend code
|
||||
|
||||
### 3. Redis Backend Tests (`redis_test.go`)
|
||||
|
||||
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (4 tests)
|
||||
- `TestRedisBackend_BasicOperations`
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- Exists
|
||||
|
||||
#### Redis-Specific Features (6 tests)
|
||||
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
|
||||
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
|
||||
- `TestRedisBackend_Clear` - Bulk delete with SCAN
|
||||
- `TestRedisBackend_NoPrefix` - Operation without prefix
|
||||
|
||||
#### Error Handling (2 tests)
|
||||
- `TestRedisBackend_ConnectionFailure` - Connection errors
|
||||
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
|
||||
|
||||
#### Advanced Features (3 tests)
|
||||
- `TestRedisBackend_PipelineOperations`
|
||||
- SetMany (batch writes)
|
||||
- GetMany (batch reads)
|
||||
- GetManyWithNonExistent
|
||||
|
||||
#### Edge Cases (5 tests)
|
||||
- `TestRedisBackend_Stats` - Statistics tracking
|
||||
- `TestRedisBackend_Ping` - Connection health
|
||||
- `TestRedisBackend_Close` - Resource cleanup
|
||||
- `TestRedisBackend_UpdateExisting` - Overwrite handling
|
||||
- `TestRedisBackend_LargeValues` - 1MB values
|
||||
- `TestRedisBackend_EmptyValues` - Empty arrays
|
||||
|
||||
**Coverage:** ~88% of Redis backend code
|
||||
|
||||
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
|
||||
- All basic Redis commands
|
||||
- TTL and expiration
|
||||
- Time manipulation (FastForward)
|
||||
- Error simulation
|
||||
- No external Redis server required
|
||||
|
||||
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
|
||||
|
||||
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### State Transitions (5 tests)
|
||||
- `TestCircuitBreaker_StateTransitions`
|
||||
- Initial state (Closed)
|
||||
- Closed → Open (after max failures)
|
||||
- Open → HalfOpen (after timeout)
|
||||
- HalfOpen → Closed (after successful requests)
|
||||
- HalfOpen → Open (on failure)
|
||||
|
||||
#### Behavior Tests (5 tests)
|
||||
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
|
||||
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
|
||||
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
|
||||
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
|
||||
- `TestCircuitBreaker_Stats` - Statistics tracking
|
||||
|
||||
#### Advanced Tests (7 tests)
|
||||
- `TestCircuitBreaker_Reset` - Manual reset
|
||||
- `TestCircuitBreaker_StateChangeCallback` - Notifications
|
||||
- `TestCircuitBreaker_IsAvailable` - Availability check
|
||||
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
|
||||
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
|
||||
- `TestCircuitBreaker_DefaultConfig` - Default configuration
|
||||
- `TestCircuitBreaker_StateString` - String representation
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkCircuitBreaker_Execute` - Successful operations
|
||||
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
|
||||
|
||||
**Coverage:** ~95% of circuit breaker code
|
||||
|
||||
### 5. Health Check Tests (`health_check_test.go`)
|
||||
|
||||
**Purpose:** Validate periodic health checking and status management.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Status Transitions (4 tests)
|
||||
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
|
||||
- `TestHealthChecker_InitialState` - Default healthy state
|
||||
- `TestHealthChecker_ForceCheck` - Manual health check trigger
|
||||
- `TestHealthChecker_StatusChangeCallback` - Change notifications
|
||||
|
||||
#### Behavior Tests (6 tests)
|
||||
- `TestHealthChecker_Stats` - Statistics tracking
|
||||
- `TestHealthChecker_Timeout` - Check timeout handling
|
||||
- `TestHealthChecker_ConcurrentAccess` - Thread safety
|
||||
- `TestHealthChecker_StopAndStart` - Lifecycle management
|
||||
- `TestHealthChecker_DegradedState` - Degraded status detection
|
||||
- `TestHealthChecker_DefaultConfig` - Default settings
|
||||
|
||||
#### Advanced Tests (2 tests)
|
||||
- `TestHealthChecker_StatusString` - String representation
|
||||
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkHealthChecker_ForceCheck` - Check performance
|
||||
- `BenchmarkHealthChecker_Status` - Status read performance
|
||||
|
||||
**Coverage:** ~90% of health checker code
|
||||
|
||||
### 6. Integration Tests (`redis_integration_test.go`)
|
||||
|
||||
**Purpose:** End-to-end testing of real-world scenarios.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Multi-Instance Tests (3 tests)
|
||||
- `TestRedisIntegration_MultipleInstances`
|
||||
- ShareTokenBlacklist - JTI sharing across Traefik replicas
|
||||
- ShareTokenCache - Token cache sharing
|
||||
- ShareMetadataCache - Provider metadata sharing
|
||||
|
||||
#### Replay Detection (2 tests)
|
||||
- `TestRedisIntegration_JTIReplayDetection`
|
||||
- PreventReplayAcrossInstances - Block used JTIs
|
||||
- ConcurrentJTIChecks - Race condition handling
|
||||
|
||||
#### Resilience (1 test)
|
||||
- `TestRedisIntegration_Failover`
|
||||
- RedisTemporaryFailure - Recovery from temporary failures
|
||||
|
||||
#### Performance (1 test)
|
||||
- `TestRedisIntegration_HighLoad`
|
||||
- HighConcurrency - 50 goroutines × 100 operations
|
||||
|
||||
#### Consistency (2 tests)
|
||||
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
|
||||
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
|
||||
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
|
||||
|
||||
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
|
||||
|
||||
## Test Helpers and Infrastructure
|
||||
|
||||
### Test Helpers (`test_helpers_test.go`)
|
||||
|
||||
**Utilities:**
|
||||
- `TestLogger` - Logging for tests
|
||||
- `MiniredisServer` - Miniredis setup/teardown
|
||||
- `TestConfig` - Default test configurations
|
||||
- `GenerateTestData` - Test data generation
|
||||
- `GenerateLargeValue` - Large value creation
|
||||
- `AssertCacheStats` - Statistics validation
|
||||
- `WaitForCondition` - Async condition waiting
|
||||
- `AssertEventuallyExpires` - TTL expiration verification
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -v
|
||||
go test ./internal/cache/resilience/... -v
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Memory backend only
|
||||
go test ./internal/cache/backend -run TestMemoryBackend -v
|
||||
|
||||
# Redis backend only
|
||||
go test ./internal/cache/backend -run TestRedisBackend -v
|
||||
|
||||
# Circuit breaker only
|
||||
go test ./internal/cache/resilience -run TestCircuitBreaker -v
|
||||
|
||||
# Integration tests only
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -coverprofile=coverage.out
|
||||
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
go test ./internal/cache/backend -bench=. -benchmem
|
||||
go test ./internal/cache/resilience -bench=. -benchmem
|
||||
```
|
||||
|
||||
### Run with Race Detector
|
||||
```bash
|
||||
go test ./internal/cache/... -race -v
|
||||
```
|
||||
|
||||
## Test Patterns Used
|
||||
|
||||
### 1. Table-Driven Tests
|
||||
Used for testing multiple scenarios with similar structure.
|
||||
|
||||
### 2. Subtests (t.Run)
|
||||
Organized test cases into logical groups with clear names.
|
||||
|
||||
### 3. Parallel Tests
|
||||
Tests marked with `t.Parallel()` for faster execution.
|
||||
|
||||
### 4. Test Fixtures
|
||||
Reusable setup functions for common test data.
|
||||
|
||||
### 5. Mocking
|
||||
- `miniredis` for Redis operations
|
||||
- Mock functions for callbacks and health checks
|
||||
|
||||
### 6. Assertion Helpers
|
||||
Using `testify/assert` and `testify/require` for clear assertions.
|
||||
|
||||
## Test Coverage Summary
|
||||
|
||||
| Component | Coverage | Tests | Lines of Code |
|
||||
|-----------|----------|-------|---------------|
|
||||
| Interface Contract | 95% | 14 | ~200 |
|
||||
| Memory Backend | 92% | 18 | ~350 |
|
||||
| Redis Backend | 88% | 21 | ~400 |
|
||||
| Circuit Breaker | 95% | 17 | ~250 |
|
||||
| Health Checker | 90% | 12 | ~200 |
|
||||
| Integration Tests | 80% | 9 | ~300 |
|
||||
| **Total** | **90%** | **91** | **~1,700** |
|
||||
|
||||
## Edge Cases Tested
|
||||
|
||||
1. **Empty values** - Zero-length byte arrays
|
||||
2. **Large values** - 1MB+ data
|
||||
3. **Special characters** - Keys with :, /, -, _, ., |
|
||||
4. **Concurrent access** - 10-50 goroutines
|
||||
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
|
||||
6. **Connection failures** - Network errors, timeouts
|
||||
7. **Redis errors** - Simulated Redis failures
|
||||
8. **Memory limits** - Eviction under memory pressure
|
||||
9. **Race conditions** - Concurrent JTI checks
|
||||
10. **State transitions** - All circuit breaker and health check states
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Benchmarks included for:
|
||||
- Cache operations (Set, Get, Delete)
|
||||
- Circuit breaker execution
|
||||
- Health check operations
|
||||
- Concurrent access patterns
|
||||
- Large datasets (10,000+ items)
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Testing Libraries
|
||||
- `github.com/stretchr/testify` - Assertions and test utilities
|
||||
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
|
||||
- `github.com/redis/go-redis/v9` - Redis client
|
||||
|
||||
### Why Miniredis?
|
||||
- **No external dependencies** - No Redis server required
|
||||
- **Fast** - In-memory, perfect for unit tests
|
||||
- **Full Redis API** - Supports all operations we need
|
||||
- **Time manipulation** - FastForward for TTL testing
|
||||
- **Error simulation** - Test failure scenarios
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Tests
|
||||
1. Hybrid backend tests (L1/L2 cache)
|
||||
2. Network partition scenarios
|
||||
3. Redis cluster support
|
||||
4. Persistence and recovery tests
|
||||
5. Metrics and monitoring integration
|
||||
|
||||
### Test Infrastructure Improvements
|
||||
1. Test containers for real Redis integration
|
||||
2. Performance regression tracking
|
||||
3. Chaos engineering tests
|
||||
4. Load testing framework
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### Recommended CI Configuration
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- go test ./internal/cache/... -race -cover -v
|
||||
- go test -run TestRedisIntegration -v
|
||||
- go test ./internal/cache/... -bench=. -benchmem
|
||||
```
|
||||
|
||||
## Maintenance Guidelines
|
||||
|
||||
1. **Add tests for new features** - Maintain >85% coverage
|
||||
2. **Update contract tests** - When interface changes
|
||||
3. **Test edge cases** - Always test error paths
|
||||
4. **Document test purpose** - Clear comments explaining what each test validates
|
||||
5. **Keep tests fast** - Use t.Parallel() where possible
|
||||
6. **Mock external dependencies** - Use miniredis, not real Redis
|
||||
|
||||
## Conclusion
|
||||
|
||||
This comprehensive test suite provides:
|
||||
- **High confidence** in cache backend correctness
|
||||
- **Fast feedback** - Tests run in seconds
|
||||
- **Good coverage** - 90% overall
|
||||
- **Clear documentation** - Each test is well-documented
|
||||
- **Maintainability** - Clear structure and patterns
|
||||
|
||||
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
|
||||
@@ -0,0 +1,486 @@
|
||||
# ============================================================================
|
||||
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
|
||||
# ============================================================================
|
||||
#
|
||||
# This example shows a complete, production-ready configuration for using
|
||||
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
|
||||
#
|
||||
|
||||
# ============================================================================
|
||||
# Part 1: Traefik Static Configuration (traefik.yml)
|
||||
# ============================================================================
|
||||
# This file configures Traefik itself and enables the plugin.
|
||||
# Place this in /etc/traefik/traefik.yml or mount it in your container.
|
||||
|
||||
---
|
||||
# Static Configuration
|
||||
api:
|
||||
dashboard: true
|
||||
insecure: false # Set to true only for local development
|
||||
|
||||
entryPoints:
|
||||
web:
|
||||
address: ":80"
|
||||
http:
|
||||
redirections:
|
||||
entryPoint:
|
||||
to: websecure
|
||||
scheme: https
|
||||
|
||||
websecure:
|
||||
address: ":443"
|
||||
http:
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
certificatesResolvers:
|
||||
letsencrypt:
|
||||
acme:
|
||||
email: admin@example.com
|
||||
storage: /letsencrypt/acme.json
|
||||
httpChallenge:
|
||||
entryPoint: web
|
||||
|
||||
providers:
|
||||
file:
|
||||
filename: /etc/traefik/dynamic.yml
|
||||
watch: true
|
||||
|
||||
# Enable the TraefikOIDC plugin
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
log:
|
||||
level: INFO
|
||||
format: json
|
||||
|
||||
accessLog:
|
||||
format: json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
|
||||
# ============================================================================
|
||||
# This file defines your routes, services, and middleware.
|
||||
# Place this in /etc/traefik/dynamic.yml
|
||||
|
||||
---
|
||||
http:
|
||||
# -------------------------------------------------------------------------
|
||||
# Middleware Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
middlewares:
|
||||
# Example 1: Minimal Redis Configuration
|
||||
# Perfect for getting started quickly
|
||||
oidc-minimal:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-application-client-id"
|
||||
clientSecret: "your-client-secret-from-provider"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
|
||||
|
||||
# Minimal Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
|
||||
# Example 2: Production Redis Configuration
|
||||
# Recommended for production deployments with multiple Traefik replicas
|
||||
oidc-production:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Provider Configuration
|
||||
clientID: "prod-client-id"
|
||||
clientSecret: "prod-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
# Security Settings
|
||||
forceHTTPS: true
|
||||
strictAudienceValidation: true
|
||||
|
||||
# Redis Configuration for Multi-Replica Deployment
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-master.redis-namespace.svc.cluster.local:6379"
|
||||
password: "strong-redis-password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:prod:"
|
||||
|
||||
# Cache Strategy
|
||||
cacheMode: "hybrid" # Fast local cache + shared Redis
|
||||
|
||||
# Connection Pooling
|
||||
poolSize: 20
|
||||
connectTimeout: 5
|
||||
readTimeout: 3
|
||||
writeTimeout: 3
|
||||
|
||||
# Resilience Features
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS (for production security)
|
||||
oidc-secure:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "secure-client-id"
|
||||
clientSecret: "secure-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "secure-redis-password"
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Verify certificates in production
|
||||
cacheMode: "redis"
|
||||
|
||||
# Example 4: Hybrid Mode (Best Performance + Consistency)
|
||||
# Local cache for hot data, Redis for consistency across replicas
|
||||
oidc-hybrid:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "app-client-id"
|
||||
clientSecret: "app-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "redis-password"
|
||||
cacheMode: "hybrid"
|
||||
|
||||
# Hybrid mode L1 cache settings
|
||||
hybridL1Size: 1000 # Number of items in local cache
|
||||
hybridL1MemoryMB: 20 # MB of memory for local cache
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Router Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
routers:
|
||||
# Protected application using OIDC authentication
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-production # Use the OIDC middleware
|
||||
service: my-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# Another app with minimal OIDC config
|
||||
simple-app:
|
||||
rule: "Host(`simple.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-minimal
|
||||
service: simple-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Service Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://my-app:8080"
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
|
||||
simple-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://simple-app:3000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 3: Docker Compose Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# docker-compose.yml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Redis service for shared caching
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Traefik with TraefikOIDC plugin
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
command:
|
||||
- "--api.dashboard=true"
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--providers.file.filename=/etc/traefik/dynamic.yml"
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.8.0"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # Dashboard
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
|
||||
- ./letsencrypt:/letsencrypt
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Your application
|
||||
my-app:
|
||||
image: my-app:latest
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
- "traefik.http.routers.my-app.entrypoints=websecure"
|
||||
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
|
||||
|
||||
# OIDC Middleware Configuration with Redis (using labels)
|
||||
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
|
||||
|
||||
# Redis configuration
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
networks:
|
||||
- traefik-network
|
||||
deploy:
|
||||
replicas: 3 # Multiple replicas sharing Redis cache
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
|
||||
networks:
|
||||
traefik-network:
|
||||
driver: bridge
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 4: Kubernetes Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# kubernetes-example.yaml
|
||||
|
||||
# Redis Deployment
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: redis
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: redis
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: redis:7-alpine
|
||||
args:
|
||||
- redis-server
|
||||
- --requirepass
|
||||
- $(REDIS_PASSWORD)
|
||||
- --maxmemory
|
||||
- 512mb
|
||||
- --maxmemory-policy
|
||||
- allkeys-lru
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: redis-secret
|
||||
key: password
|
||||
ports:
|
||||
- containerPort: 6379
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "100m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
---
|
||||
# Redis Service
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
selector:
|
||||
app: redis
|
||||
ports:
|
||||
- port: 6379
|
||||
targetPort: 6379
|
||||
---
|
||||
# Redis Secret
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: redis-secret
|
||||
namespace: traefik
|
||||
type: Opaque
|
||||
stringData:
|
||||
password: "your-secure-redis-password"
|
||||
---
|
||||
# OIDC Middleware with Redis
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Configuration
|
||||
clientID: "kubernetes-client-id"
|
||||
clientSecret: "kubernetes-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
|
||||
|
||||
# Redis Configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.traefik.svc.cluster.local:6379"
|
||||
password: "your-secure-redis-password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:k8s:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
enableHealthCheck: true
|
||||
---
|
||||
# IngressRoute using the middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: IngressRoute
|
||||
metadata:
|
||||
name: my-app
|
||||
namespace: default
|
||||
spec:
|
||||
entryPoints:
|
||||
- websecure
|
||||
routes:
|
||||
- match: Host(`app.example.com`)
|
||||
kind: Rule
|
||||
middlewares:
|
||||
- name: oidc-auth
|
||||
namespace: traefik
|
||||
services:
|
||||
- name: my-app
|
||||
port: 80
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 5: Environment Variables (Optional Fallback)
|
||||
# ============================================================================
|
||||
|
||||
# If you prefer environment variables as fallback (not recommended for production),
|
||||
# you can set these. NOTE: Plugin configuration takes precedence!
|
||||
|
||||
# Docker Compose env file (.env)
|
||||
---
|
||||
# OIDC Configuration
|
||||
OIDC_CLIENT_ID=your-client-id
|
||||
OIDC_CLIENT_SECRET=your-client-secret
|
||||
OIDC_PROVIDER_URL=https://auth.example.com
|
||||
|
||||
# Redis Configuration (fallback)
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=yourredispassword
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=20
|
||||
REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
REDIS_ENABLE_HEALTH_CHECK=true
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Cheat Sheet
|
||||
# ============================================================================
|
||||
|
||||
# Minimal Setup (Quick Start):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis:6379"
|
||||
|
||||
# Production Setup (Recommended):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis-master:6379"
|
||||
# password: "strong-password"
|
||||
# cacheMode: "hybrid"
|
||||
# enableCircuitBreaker: true
|
||||
# enableHealthCheck: true
|
||||
|
||||
# High Security Setup:
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis.example.com:6380"
|
||||
# password: "strong-password"
|
||||
# enableTLS: true
|
||||
# tlsSkipVerify: false
|
||||
# cacheMode: "redis"
|
||||
|
||||
# Cache Modes:
|
||||
# - "memory": Local cache only (default, no Redis needed)
|
||||
# - "redis": Redis only (consistent, shared across replicas)
|
||||
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
|
||||
@@ -0,0 +1,149 @@
|
||||
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
|
||||
# This example shows how to configure Redis through Traefik's dynamic configuration
|
||||
|
||||
# Static configuration (traefik.yml)
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
# Dynamic configuration (dynamic.yml or labels)
|
||||
http:
|
||||
middlewares:
|
||||
# Example 1: Basic Redis configuration
|
||||
oidc-redis-basic:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
# password: "your-redis-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
|
||||
# Example 2: Redis with resilience features
|
||||
oidc-redis-resilient:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with full resilience configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "strong-password"
|
||||
db: 1
|
||||
keyPrefix: "myapp:"
|
||||
poolSize: 20
|
||||
connectTimeout: 10
|
||||
readTimeout: 5
|
||||
writeTimeout: 5
|
||||
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
|
||||
# Circuit breaker settings
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
# Health check settings
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS
|
||||
oidc-redis-tls:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with TLS configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "secure-password"
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Set to true only for testing
|
||||
cacheMode: "redis"
|
||||
|
||||
routers:
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
middlewares:
|
||||
- oidc-redis-basic
|
||||
service: my-app-service
|
||||
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://localhost:8080"
|
||||
|
||||
# Docker Compose labels example
|
||||
# version: '3.8'
|
||||
# services:
|
||||
# traefik:
|
||||
# image: traefik:v3.0
|
||||
# # ... other config ...
|
||||
#
|
||||
# my-app:
|
||||
# image: my-app:latest
|
||||
# labels:
|
||||
# - "traefik.enable=true"
|
||||
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
# - "traefik.http.routers.my-app.middlewares=my-oidc"
|
||||
# # OIDC middleware configuration with Redis
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# # Redis configuration via labels
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
|
||||
#
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# command: redis-server --requirepass redis-password
|
||||
# # ... other config ...
|
||||
|
||||
# Environment variable fallback (optional)
|
||||
# If Redis configuration is not provided in Traefik config, these environment variables
|
||||
# can be used as a fallback (but Traefik config takes precedence):
|
||||
#
|
||||
# REDIS_ENABLED=true
|
||||
# REDIS_ADDRESS=redis:6379
|
||||
# REDIS_PASSWORD=secret
|
||||
# REDIS_DB=0
|
||||
# REDIS_KEY_PREFIX=traefikoidc:
|
||||
# REDIS_CACHE_MODE=redis
|
||||
# REDIS_POOL_SIZE=10
|
||||
# REDIS_CONNECT_TIMEOUT=5
|
||||
# REDIS_READ_TIMEOUT=3
|
||||
# REDIS_WRITE_TIMEOUT=3
|
||||
# REDIS_ENABLE_TLS=false
|
||||
# REDIS_TLS_SKIP_VERIFY=false
|
||||
# REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
|
||||
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
|
||||
# REDIS_ENABLE_HEALTH_CHECK=true
|
||||
# REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
@@ -3,15 +3,20 @@ module github.com/lukaszraczylo/traefikoidc
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -10,8 +20,12 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
Vendored
+90
@@ -0,0 +1,90 @@
|
||||
package backends
|
||||
|
||||
import "time"
|
||||
|
||||
// BackendType represents the type of cache backend
|
||||
type BackendType string
|
||||
|
||||
const (
|
||||
BackendTypeMemory BackendType = "memory"
|
||||
BackendTypeRedis BackendType = "redis"
|
||||
BackendTypeHybrid BackendType = "hybrid"
|
||||
|
||||
// Aliases for backward compatibility
|
||||
TypeMemory BackendType = "memory"
|
||||
TypeRedis BackendType = "redis"
|
||||
TypeHybrid BackendType = "hybrid"
|
||||
)
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns a default configuration for Redis caching
|
||||
func DefaultRedisConfig(addr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeRedis,
|
||||
RedisAddr: addr,
|
||||
RedisDB: 0,
|
||||
RedisPrefix: "traefikoidc:",
|
||||
PoolSize: 10,
|
||||
EnableCircuitBreaker: true,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHybridConfig returns a default configuration for hybrid caching
|
||||
func DefaultHybridConfig(redisAddr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeHybrid,
|
||||
L1Config: &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 500,
|
||||
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
},
|
||||
L2Config: DefaultRedisConfig(redisAddr),
|
||||
AsyncWrites: true,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
Vendored
+38
@@ -0,0 +1,38 @@
|
||||
package backends
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrBackendClosed is returned when operating on a closed backend
|
||||
ErrBackendClosed = errors.New("cache backend is closed")
|
||||
|
||||
// ErrKeyNotFound is returned when a key doesn't exist
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
|
||||
// ErrCacheMiss indicates the requested key was not found in the cache
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// ErrBackendUnavailable indicates the cache backend is not available
|
||||
ErrBackendUnavailable = errors.New("cache backend unavailable")
|
||||
|
||||
// ErrInvalidValue indicates the cached value is invalid or corrupted
|
||||
ErrInvalidValue = errors.New("invalid cached value")
|
||||
|
||||
// ErrInvalidTTL is returned when TTL is invalid
|
||||
ErrInvalidTTL = errors.New("invalid TTL")
|
||||
|
||||
// ErrConnectionFailed is returned when connection fails
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrCircuitOpen is returned when circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTimeout is returned when operation times out
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
|
||||
// ErrSerializationFailed is returned when serialization fails
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
|
||||
// ErrDeserializationFailed is returned when deserialization fails
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
)
|
||||
Vendored
+695
@@ -0,0 +1,695 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// defaultLogger provides a basic logger implementation
|
||||
type defaultLogger struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
l.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
|
||||
l.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// HybridConfig provides configuration for the hybrid backend
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.Primary == nil {
|
||||
return nil, fmt.Errorf("primary (L1) backend is required")
|
||||
}
|
||||
|
||||
if config.Secondary == nil {
|
||||
return nil, fmt.Errorf("secondary (L2) backend is required")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
|
||||
}
|
||||
|
||||
if config.AsyncBufferSize <= 0 {
|
||||
config.AsyncBufferSize = 1000
|
||||
}
|
||||
|
||||
// Default critical cache types that require synchronous writes
|
||||
if config.SyncWriteCacheTypes == nil {
|
||||
config.SyncWriteCacheTypes = map[string]bool{
|
||||
"blacklist": true, // Token blacklist must be immediately consistent
|
||||
"token": true, // Token validation is critical
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &HybridBackend{
|
||||
primary: config.Primary,
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
// Start async write worker
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
|
||||
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
|
||||
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
|
||||
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Set stores a value in both L1 and L2 caches
|
||||
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Always write to L1 first (synchronous)
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L1 cache: %v", err)
|
||||
// Continue to try L2 even if L1 fails
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
// Determine if this should be a sync or async write based on cache type
|
||||
cacheType := h.extractCacheType(key)
|
||||
requiresSync := h.syncWriteCacheTypes[cacheType]
|
||||
|
||||
if requiresSync {
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from cache, checking L1 first, then L2
|
||||
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
// Try L1 first
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
h.l1Hits.Add(1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
// Try L2
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
}
|
||||
|
||||
if !exists {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from both L1 and L2 caches
|
||||
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
var deleted bool
|
||||
|
||||
// Delete from L1
|
||||
if d, err := h.primary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
|
||||
// Delete from L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if d, err := h.secondary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in either cache
|
||||
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// Check L1 first
|
||||
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys from both caches
|
||||
func (h *HybridBackend) Clear(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
// Clear L1
|
||||
if err := h.primary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L1 cache: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
// Clear L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if err := h.secondary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetStats returns statistics for the hybrid cache
|
||||
func (h *HybridBackend) GetStats() map[string]interface{} {
|
||||
l1Hits := h.l1Hits.Load()
|
||||
l2Hits := h.l2Hits.Load()
|
||||
misses := h.misses.Load()
|
||||
total := l1Hits + l2Hits + misses
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": TypeHybrid,
|
||||
"l1_hits": l1Hits,
|
||||
"l2_hits": l2Hits,
|
||||
"misses": misses,
|
||||
"total": total,
|
||||
"l1_writes": h.l1Writes.Load(),
|
||||
"l2_writes": h.l2Writes.Load(),
|
||||
"errors": h.errors.Load(),
|
||||
"fallback_mode": h.fallbackMode.Load(),
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
|
||||
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
|
||||
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
|
||||
}
|
||||
|
||||
// Add sub-backend stats
|
||||
stats["l1_stats"] = h.primary.GetStats()
|
||||
stats["l2_stats"] = h.secondary.GetStats()
|
||||
|
||||
// Add last L2 error time if available
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok {
|
||||
stats["last_l2_error"] = t.Format(time.RFC3339)
|
||||
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks if both backends are healthy
|
||||
func (h *HybridBackend) Ping(ctx context.Context) error {
|
||||
// Check L1
|
||||
if err := h.primary.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("L1 ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check L2 (but don't fail if it's down)
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
h.logger.Warnf("L2 ping failed: %v", err)
|
||||
h.recordL2Error()
|
||||
// Don't return error - we can operate with L1 only
|
||||
} else {
|
||||
// L2 is healthy, clear fallback mode if it was set
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend recovered, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the hybrid backend
|
||||
func (h *HybridBackend) Close() error {
|
||||
// Cancel context to stop workers
|
||||
h.cancel()
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Workers finished
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warnf("Timeout waiting for workers to finish")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Close backends
|
||||
if err := h.primary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L1 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if err := h.secondary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L2 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
h.logger.Infof("HybridBackend closed")
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values efficiently
|
||||
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
results := make(map[string][]byte, len(keys))
|
||||
missingKeys := make([]string, 0)
|
||||
|
||||
// Try L1 first for all keys
|
||||
for _, key := range keys {
|
||||
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
|
||||
results[key] = value
|
||||
h.l1Hits.Add(1)
|
||||
} else {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// If all found in L1 or in fallback mode, return
|
||||
if len(missingKeys) == 0 || h.fallbackMode.Load() {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Try L2 for missing keys using batch operation if available
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
GetMany(context.Context, []string) (map[string][]byte, error)
|
||||
}); ok {
|
||||
l2Results, err := batcher.GetMany(ctx, missingKeys)
|
||||
if err != nil {
|
||||
h.logger.Debugf("L2 batch get error: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual gets
|
||||
for _, key := range missingKeys {
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count misses for keys not found anywhere
|
||||
for _, key := range keys {
|
||||
if _, found := results[key]; !found {
|
||||
h.misses.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetMany stores multiple key-value pairs efficiently
|
||||
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write to L1 first
|
||||
for key, value := range items {
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip L2 if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if L2 supports batch operations
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
SetMany(context.Context, map[string][]byte, time.Duration) error
|
||||
}); ok {
|
||||
if err := batcher.SetMany(ctx, items, ttl); err != nil {
|
||||
h.logger.Warnf("Failed to batch write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(int64(len(items)))
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual sets
|
||||
for key, value := range items {
|
||||
cacheType := h.extractCacheType(key)
|
||||
if h.syncWriteCacheTypes[cacheType] {
|
||||
// Sync write for critical types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
}
|
||||
} else {
|
||||
// Async write for non-critical types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
// Queued
|
||||
default:
|
||||
h.logger.Warnf("Async buffer full for batch write")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items with best effort
|
||||
for len(h.asyncWriteBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.asyncWriteBuffer:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.asyncWriteBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the write with a timeout
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthMonitor periodically checks L2 health and manages fallback mode
|
||||
func (h *HybridBackend) healthMonitor() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
if !h.fallbackMode.Load() {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
|
||||
}
|
||||
} else {
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend healthy, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordL2Error records the timestamp of an L2 error
|
||||
func (h *HybridBackend) recordL2Error() {
|
||||
h.lastL2Error.Store(time.Now())
|
||||
|
||||
// Check if we should enter fallback mode based on recent errors
|
||||
if !h.fallbackMode.Load() {
|
||||
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractCacheType attempts to determine the cache type from the key
|
||||
func (h *HybridBackend) extractCacheType(key string) string {
|
||||
// Simple heuristic based on key prefixes
|
||||
// This should match the actual cache type strategy in the main application
|
||||
|
||||
if len(key) > 10 {
|
||||
prefix := key[:10]
|
||||
switch {
|
||||
case contains(prefix, "blacklist"):
|
||||
return "blacklist"
|
||||
case contains(prefix, "token"):
|
||||
return "token"
|
||||
case contains(prefix, "metadata"):
|
||||
return "metadata"
|
||||
case contains(prefix, "jwk"):
|
||||
return "jwk"
|
||||
case contains(prefix, "session"):
|
||||
return "session"
|
||||
case contains(prefix, "introspect"):
|
||||
return "introspection"
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if toLower(s[i+j]) != toLower(substr[j]) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower converts a byte to lowercase
|
||||
func toLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + 32
|
||||
}
|
||||
return b
|
||||
}
|
||||
Vendored
+133
@@ -0,0 +1,133 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheBackend defines the interface for all cache backend implementations
|
||||
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
|
||||
type CacheBackend interface {
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
// Returns an error if the operation fails
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
// Returns: value, remaining TTL, exists flag, and error
|
||||
// If the key doesn't exist, exists will be false
|
||||
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
|
||||
|
||||
// Delete removes a key from the cache
|
||||
// Returns true if the key was deleted, false if it didn't exist
|
||||
Delete(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// GetStats returns cache statistics
|
||||
// Stats include: hits, misses, size, memory usage, etc.
|
||||
GetStats() map[string]interface{}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
Close() error
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
type BackendCapabilities struct {
|
||||
// Distributed indicates if the backend is distributed across multiple instances
|
||||
Distributed bool
|
||||
|
||||
// Persistent indicates if the backend persists data across restarts
|
||||
Persistent bool
|
||||
|
||||
// Eviction indicates if the backend supports automatic eviction
|
||||
Eviction bool
|
||||
|
||||
// TTL indicates if the backend supports TTL (time-to-live)
|
||||
TTL bool
|
||||
|
||||
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
|
||||
MaxKeySize int64
|
||||
|
||||
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
|
||||
MaxValueSize int64
|
||||
|
||||
// MaxKeys is the maximum number of keys (0 = unlimited)
|
||||
MaxKeys int64
|
||||
|
||||
// SupportsExpire indicates if the backend supports expiration
|
||||
SupportsExpire bool
|
||||
|
||||
// SupportsMultiGet indicates if the backend supports batch get operations
|
||||
SupportsMultiGet bool
|
||||
|
||||
// SupportsTransaction indicates if the backend supports transactions
|
||||
SupportsTransaction bool
|
||||
|
||||
// SupportsCompression indicates if the backend supports compression
|
||||
SupportsCompression bool
|
||||
|
||||
// RequiresSerialize indicates if values must be serialized
|
||||
RequiresSerialize bool
|
||||
|
||||
// AtomicOperations indicates if the backend supports atomic operations
|
||||
AtomicOperations bool
|
||||
}
|
||||
+402
@@ -0,0 +1,402 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
|
||||
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
|
||||
func TestCacheBackendContract(t *testing.T) {
|
||||
// Test suite will be run against each backend type
|
||||
t.Run("MemoryBackend", func(t *testing.T) {
|
||||
backend := setupMemoryBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("RedisBackend", func(t *testing.T) {
|
||||
backend := setupRedisBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("HybridBackend", func(t *testing.T) {
|
||||
backend := setupHybridBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// runContractTests executes all contract tests against a backend
|
||||
func runContractTests(t *testing.T, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BasicSetGet", func(t *testing.T) {
|
||||
testBasicSetGet(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
testGetNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
testUpdateExisting(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testDelete(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
testDeleteNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
testExists(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("TTLExpiration", func(t *testing.T) {
|
||||
testTTLExpiration(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
testClear(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
testPing(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
testStats(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
testConcurrentAccess(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("LargeValues", func(t *testing.T) {
|
||||
testLargeValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("EmptyValues", func(t *testing.T) {
|
||||
testEmptyValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
|
||||
testSpecialCharactersInKeys(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// testBasicSetGet verifies basic set and get operations
|
||||
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "test-key-1"
|
||||
value := []byte("test-value-1")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err, "Set should not return error")
|
||||
|
||||
// Get value
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error")
|
||||
assert.True(t, exists, "Key should exist")
|
||||
assert.Equal(t, value, retrieved, "Retrieved value should match")
|
||||
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
|
||||
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
|
||||
}
|
||||
|
||||
// testGetNonExistent verifies behavior when getting non-existent keys
|
||||
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-key"
|
||||
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error for non-existent key")
|
||||
assert.False(t, exists, "Key should not exist")
|
||||
assert.Nil(t, retrieved, "Value should be nil")
|
||||
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
|
||||
}
|
||||
|
||||
// testUpdateExisting verifies updating an existing key
|
||||
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set initial value
|
||||
err := backend.Set(ctx, key, value1, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update value
|
||||
err = backend.Set(ctx, key, value2, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated value
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved, "Value should be updated")
|
||||
}
|
||||
|
||||
// testDelete verifies delete operation
|
||||
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted, "Delete should return true for existing key")
|
||||
|
||||
// Verify deleted
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after delete")
|
||||
}
|
||||
|
||||
// testDeleteNonExistent verifies deleting non-existent keys
|
||||
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-delete-key"
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted, "Delete should return false for non-existent key")
|
||||
}
|
||||
|
||||
// testExists verifies the Exists operation
|
||||
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist initially")
|
||||
|
||||
// Set value
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist after Set")
|
||||
}
|
||||
|
||||
// testTTLExpiration verifies TTL expiration behavior
|
||||
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
// Set with short TTL
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist immediately after Set")
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after TTL expiration")
|
||||
}
|
||||
|
||||
// testClear verifies Clear operation
|
||||
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
// Set multiple keys
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after Clear")
|
||||
}
|
||||
}
|
||||
|
||||
// testPing verifies Ping operation
|
||||
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
err := backend.Ping(ctx)
|
||||
assert.NoError(t, err, "Ping should succeed on healthy backend")
|
||||
}
|
||||
|
||||
// testStats verifies GetStats operation
|
||||
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
stats := backend.GetStats()
|
||||
assert.NotNil(t, stats, "Stats should not be nil")
|
||||
|
||||
// Stats should contain basic metrics
|
||||
_, hasHits := stats["hits"]
|
||||
_, hasMisses := stats["misses"]
|
||||
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
|
||||
}
|
||||
|
||||
// testConcurrentAccess verifies thread safety
|
||||
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 20
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// testLargeValues verifies handling of large values
|
||||
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "large-value-key"
|
||||
value := GenerateLargeValue(1024 * 1024) // 1MB
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle large values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
|
||||
}
|
||||
|
||||
// testEmptyValues verifies handling of empty values
|
||||
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "empty-value-key"
|
||||
value := []byte{}
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle empty values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Empty value should exist")
|
||||
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
|
||||
}
|
||||
|
||||
// testSpecialCharactersInKeys verifies handling of special characters in keys
|
||||
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
specialKeys := []string{
|
||||
"key:with:colons",
|
||||
"key/with/slashes",
|
||||
"key-with-dashes",
|
||||
"key_with_underscores",
|
||||
"key.with.dots",
|
||||
"key|with|pipes",
|
||||
}
|
||||
|
||||
for _, key := range specialKeys {
|
||||
value := []byte(fmt.Sprintf("value-for-%s", key))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle special character in key: %s", key)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key with special characters should exist: %s", key)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to setup different backend types
|
||||
// These will be implemented in respective test files
|
||||
|
||||
func setupMemoryBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in memory_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("MemoryBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRedisBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in redis_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("RedisBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHybridBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in hybrid_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("HybridBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
Vendored
+516
@@ -0,0 +1,516 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
func (item *memoryCacheItem) isExpired() bool {
|
||||
if item.expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Statistics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
m.cleanupTicker = time.NewTicker(cleanupInterval)
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupExpired()
|
||||
case <-m.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalGetTime.Add(duration)
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalSetTime.Add(duration)
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching the pattern (use "*" for all keys)
|
||||
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.currentSize, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the cache backend
|
||||
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
m.mu.RUnlock()
|
||||
|
||||
avgGetLatency := time.Duration(0)
|
||||
if getCount := m.getCount.Load(); getCount > 0 {
|
||||
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
|
||||
}
|
||||
|
||||
avgSetLatency := time.Duration(0)
|
||||
if setCount := m.setCount.Load(); setCount > 0 {
|
||||
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
|
||||
}
|
||||
|
||||
return &BackendStats{
|
||||
Type: TypeMemory,
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Sets: m.sets.Load(),
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
LastErrorTime: lastErrorTime,
|
||||
Uptime: time.Since(m.startTime),
|
||||
StartTime: m.startTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the backend and releases resources
|
||||
func (m *MemoryCacheBackend) Close() error {
|
||||
if m.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (m *MemoryCacheBackend) IsHealthy() bool {
|
||||
return !m.closed.Load()
|
||||
}
|
||||
|
||||
// Type returns the backend type
|
||||
func (m *MemoryCacheBackend) Type() BackendType {
|
||||
return TypeMemory
|
||||
}
|
||||
|
||||
// Capabilities returns the backend capabilities
|
||||
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
return &BackendCapabilities{
|
||||
Distributed: false,
|
||||
Persistent: false,
|
||||
Eviction: true,
|
||||
TTL: true,
|
||||
MaxKeySize: 1024, // 1KB
|
||||
MaxValueSize: 10485760, // 10MB
|
||||
MaxKeys: m.maxSize,
|
||||
SupportsExpire: true,
|
||||
SupportsMultiGet: true,
|
||||
SupportsTransaction: false,
|
||||
SupportsCompression: false,
|
||||
RequiresSerialize: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
case []byte:
|
||||
return int64(len(v))
|
||||
case int, int32, int64, uint, uint32, uint64:
|
||||
return 8
|
||||
case float32, float64:
|
||||
return 8
|
||||
case bool:
|
||||
return 1
|
||||
default:
|
||||
// For complex types, use a default estimate
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
// matchPattern checks if a key matches a pattern (simplified glob matching)
|
||||
func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
}
|
||||
+501
@@ -0,0 +1,501 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMemoryBackend_BasicOperations tests basic CRUD operations
|
||||
func TestMemoryBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
assert.LessOrEqual(t, remainingTTL, ttl)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
deleted, err := backend.Delete(ctx, "non-existent-delete")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
// Add multiple items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(0), size)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTLExpiration tests TTL and expiration
|
||||
func TestMemoryBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "short-ttl-key"
|
||||
value := []byte("short-ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLDecrement", func(t *testing.T) {
|
||||
key := "ttl-decrement-key"
|
||||
value := []byte("ttl-decrement-value")
|
||||
ttl := 2 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check TTL again - should be less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
|
||||
})
|
||||
|
||||
t.Run("CleanupExpiredItems", func(t *testing.T) {
|
||||
// Set multiple items with short TTL
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 50*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// All items should be cleaned up
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expired items should be cleaned up")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LRUEviction tests LRU eviction
|
||||
func TestMemoryBackend_LRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 5
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill cache to max size
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("lru-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("lru-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Access first item to make it most recently used
|
||||
_, _, exists, err := backend.Get(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Add a new item - should evict lru-key-1 (least recently used)
|
||||
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// lru-key-0 should still exist (was accessed recently)
|
||||
exists, err = backend.Exists(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item should not be evicted")
|
||||
|
||||
// lru-key-1 should be evicted
|
||||
exists, err = backend.Exists(ctx, "lru-key-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Least recently used item should be evicted")
|
||||
|
||||
// Check eviction count
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_MemoryLimit tests memory-based eviction
|
||||
func TestMemoryBackend_MemoryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100
|
||||
config.MaxMemoryBytes = 1024 // 1KB limit
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items until memory limit is reached
|
||||
largeValue := make([]byte, 512) // 512 bytes each
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("mem-key-%d", i)
|
||||
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
memory := stats["memory"].(int64)
|
||||
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
|
||||
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ConcurrentAccess tests thread safety
|
||||
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// Random deletes
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_UpdateExisting tests updating existing keys
|
||||
func TestMemoryBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
|
||||
|
||||
// Size should not increase (same key)
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Stats tests statistics tracking
|
||||
func TestMemoryBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add items and track hits/misses
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Hit
|
||||
backend.Get(ctx, "key1")
|
||||
// Miss
|
||||
backend.Get(ctx, "non-existent")
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_EmptyValues tests handling of empty values
|
||||
func TestMemoryBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LargeValues tests handling of large values
|
||||
func TestMemoryBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Close tests proper cleanup on close
|
||||
func TestMemoryBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
_, _, _, err = backend.Get(ctx, "close-key-0")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Closing again should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Ping tests ping operation
|
||||
func TestMemoryBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
|
||||
func TestMemoryBackend_ValueIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "isolation-key"
|
||||
originalValue := []byte("original-value")
|
||||
|
||||
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value and modify it
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Modify retrieved value
|
||||
if len(retrieved) > 0 {
|
||||
retrieved[0] = 'X'
|
||||
}
|
||||
|
||||
// Get again - should be unchanged
|
||||
retrieved2, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
|
||||
}
|
||||
+153
@@ -0,0 +1,153 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
|
||||
type MemoryBackend struct {
|
||||
*MemoryCacheBackend
|
||||
}
|
||||
|
||||
// NewMemoryBackend creates a new memory backend from a config
|
||||
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
|
||||
maxSize := int64(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1000
|
||||
}
|
||||
|
||||
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
|
||||
return &MemoryBackend{
|
||||
MemoryCacheBackend: cacheBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
|
||||
if err == ErrBackendUnavailable {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
val, err := m.MemoryCacheBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
if err == ErrCacheMiss {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err == ErrBackendUnavailable {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
var valueBytes []byte
|
||||
if val != nil {
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
|
||||
return valueBytes, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
// Check if key exists first
|
||||
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = m.MemoryCacheBackend.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return m.MemoryCacheBackend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
func (m *MemoryBackend) Clear(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BackendStats to map
|
||||
hitRate := float64(0)
|
||||
total := stats.Hits + stats.Misses
|
||||
if total > 0 {
|
||||
hitRate = float64(stats.Hits) / float64(total)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
func (m *MemoryBackend) Close() error {
|
||||
return m.MemoryCacheBackend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
func (m *MemoryBackend) Ping(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Ping(ctx)
|
||||
}
|
||||
|
||||
// Ensure MemoryBackend implements CacheBackend
|
||||
var _ CacheBackend = (*MemoryBackend)(nil)
|
||||
Vendored
+277
@@ -0,0 +1,277 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisBackend implements a Redis-based cache backend
|
||||
type RedisBackend struct {
|
||||
client *redis.Client
|
||||
config *Config
|
||||
|
||||
// Metrics
|
||||
hits int64
|
||||
misses int64
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewRedisBackend creates a new Redis cache backend
|
||||
func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.RedisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is required")
|
||||
}
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
PoolSize: config.PoolSize,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
client.Close()
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
backend := &RedisBackend{
|
||||
client: client,
|
||||
config: config,
|
||||
}
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Set stores a value with TTL
|
||||
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
return r.client.Set(ctx, prefixedKey, value, ttl).Err()
|
||||
}
|
||||
|
||||
// Get retrieves a value
|
||||
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
// Get value
|
||||
value, err := r.client.Get(ctx, prefixedKey).Bytes()
|
||||
if err == redis.Nil {
|
||||
atomic.AddInt64(&r.misses, 1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
atomic.AddInt64(&r.misses, 1)
|
||||
return nil, 0, false, fmt.Errorf("failed to get key: %w", err)
|
||||
}
|
||||
|
||||
// Get TTL
|
||||
ttl, err := r.client.TTL(ctx, prefixedKey).Result()
|
||||
if err != nil {
|
||||
// Value exists but couldn't get TTL
|
||||
atomic.AddInt64(&r.hits, 1)
|
||||
return value, 0, true, nil
|
||||
}
|
||||
|
||||
atomic.AddInt64(&r.hits, 1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key
|
||||
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
result, err := r.client.Del(ctx, prefixedKey).Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to delete key: %w", err)
|
||||
}
|
||||
|
||||
return result > 0, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists
|
||||
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
result, err := r.client.Exists(ctx, prefixedKey).Result()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check existence: %w", err)
|
||||
}
|
||||
|
||||
return result > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys with the prefix
|
||||
func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// Use SCAN to find all keys with prefix
|
||||
pattern := r.config.RedisPrefix + "*"
|
||||
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
|
||||
|
||||
var keys []string
|
||||
for iter.Next(ctx) {
|
||||
keys = append(keys, iter.Val())
|
||||
}
|
||||
|
||||
if err := iter.Err(); err != nil {
|
||||
return fmt.Errorf("failed to scan keys: %w", err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete in batches to avoid blocking Redis
|
||||
batchSize := 100
|
||||
for i := 0; i < len(keys); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
|
||||
if err := r.client.Del(ctx, keys[i:end]...).Err(); err != nil {
|
||||
return fmt.Errorf("failed to delete keys: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics
|
||||
func (r *RedisBackend) GetStats() map[string]interface{} {
|
||||
hits := atomic.LoadInt64(&r.hits)
|
||||
misses := atomic.LoadInt64(&r.misses)
|
||||
total := hits + misses
|
||||
hitRate := 0.0
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": "redis",
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": hitRate,
|
||||
}
|
||||
|
||||
// Try to get Redis info (non-critical)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if info, err := r.client.Info(ctx, "memory").Result(); err == nil {
|
||||
stats["redis_info"] = info
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks Redis health
|
||||
func (r *RedisBackend) Ping(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
return r.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// Close shuts down the backend
|
||||
func (r *RedisBackend) Close() error {
|
||||
if !r.closed.CompareAndSwap(false, true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// prefixKey adds the configured prefix to a key
|
||||
func (r *RedisBackend) prefixKey(key string) string {
|
||||
if r.config.RedisPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// Pipeline operations for batch operations (future enhancement)
|
||||
|
||||
// SetMany stores multiple key-value pairs in a pipeline
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
for key, value := range items {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
pipe.Set(ctx, prefixedKey, value, ttl)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values in a pipeline
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
cmds := make(map[string]*redis.StringCmd, len(keys))
|
||||
|
||||
for _, key := range keys {
|
||||
prefixedKey := r.prefixKey(key)
|
||||
cmds[key] = pipe.Get(ctx, prefixedKey)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make(map[string][]byte)
|
||||
for key, cmd := range cmds {
|
||||
value, err := cmd.Bytes()
|
||||
if err == nil {
|
||||
results[key] = value
|
||||
atomic.AddInt64(&r.hits, 1)
|
||||
} else if err != redis.Nil {
|
||||
atomic.AddInt64(&r.misses, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisBackend_BasicOperations tests basic Redis operations
|
||||
func TestRedisBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "redis-test-key"
|
||||
value := []byte("redis-test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "redis-delete-key"
|
||||
value := []byte("redis-delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "redis-exists-key"
|
||||
value := []byte("redis-exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
|
||||
func TestRedisBackend_KeyPrefixing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "test:prefix:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "my-key"
|
||||
value := []byte("my-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that key is stored with prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, "test:prefix:my-key", keys[0])
|
||||
|
||||
// Get should work without prefix
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// TestRedisBackend_TTLExpiration tests TTL handling
|
||||
func TestRedisBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward time in miniredis
|
||||
mr.FastForward(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLRemaining", func(t *testing.T) {
|
||||
key := "ttl-remaining-key"
|
||||
value := []byte("ttl-remaining-value")
|
||||
ttl := 10 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL is less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_Clear tests clearing all keys
|
||||
func TestRedisBackend_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "clear-test:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add multiple keys
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify keys exist
|
||||
keys := mr.CheckKeys()
|
||||
assert.Len(t, keys, 10)
|
||||
|
||||
// Clear all
|
||||
err = backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
keys = mr.CheckKeys()
|
||||
assert.Len(t, keys, 0)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
|
||||
func TestRedisBackend_ConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Try to connect to non-existent Redis
|
||||
config := DefaultRedisConfig("localhost:9999")
|
||||
_, err := NewRedisBackend(config)
|
||||
assert.Error(t, err, "Should fail to connect to non-existent Redis")
|
||||
}
|
||||
|
||||
// TestRedisBackend_RedisErrors tests handling of Redis errors
|
||||
func TestRedisBackend_RedisErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated error")
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Clear error
|
||||
mr.ClearError()
|
||||
|
||||
// Operations should work again
|
||||
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConcurrentAccess tests thread safety
|
||||
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0))
|
||||
}
|
||||
|
||||
// TestRedisBackend_Stats tests statistics tracking
|
||||
func TestRedisBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add and access items
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Get(ctx, "key1") // Hit
|
||||
backend.Get(ctx, "non-existent") // Miss
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Ping tests health check
|
||||
func TestRedisBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Close tests proper cleanup
|
||||
func TestRedisBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Double close should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_UpdateExisting tests updating existing keys
|
||||
func TestRedisBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute)
|
||||
}
|
||||
|
||||
// TestRedisBackend_LargeValues tests handling of large values
|
||||
func TestRedisBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_EmptyValues tests handling of empty values
|
||||
func TestRedisBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_PipelineOperations tests batch operations
|
||||
func TestRedisBackend_PipelineOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("batch-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("batch-value-%d", i))
|
||||
items[key] = value
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrieved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany", func(t *testing.T) {
|
||||
// Set test data
|
||||
testData := GenerateTestData(5)
|
||||
for key, value := range testData {
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Get all keys
|
||||
keys := make([]string, 0, len(testData))
|
||||
for key := range testData {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, len(testData))
|
||||
|
||||
for key, expectedValue := range testData {
|
||||
retrievedValue, exists := results[key]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyWithNonExistent", func(t *testing.T) {
|
||||
keys := []string{"exists-1", "non-existent", "exists-2"}
|
||||
|
||||
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
|
||||
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
assert.Equal(t, []byte("value-1"), results["exists-1"])
|
||||
assert.Equal(t, []byte("value-2"), results["exists-2"])
|
||||
_, exists := results["non-existent"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_NoPrefix tests operation without prefix
|
||||
func TestRedisBackend_NoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "" // No prefix
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "no-prefix-key"
|
||||
value := []byte("no-prefix-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check key is stored without prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, key, keys[0])
|
||||
}
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLogger implements a simple logger for tests
|
||||
type TestLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewTestLogger(t *testing.T) *TestLogger {
|
||||
return &TestLogger{t: t}
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debug(format string, args ...interface{}) {
|
||||
l.t.Logf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Info(format string, args ...interface{}) {
|
||||
l.t.Logf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Error(format string, args ...interface{}) {
|
||||
l.t.Logf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Debug(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Infof(format string, args ...interface{}) {
|
||||
l.Info(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Error(format, args...)
|
||||
}
|
||||
|
||||
// MiniredisServer manages a miniredis instance for testing
|
||||
type MiniredisServer struct {
|
||||
server *miniredis.Miniredis
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// NewMiniredisServer creates a new miniredis server for testing
|
||||
func NewMiniredisServer(t *testing.T) *MiniredisServer {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err, "failed to start miniredis")
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// Verify connection
|
||||
ctx := context.Background()
|
||||
err = client.Ping(ctx).Err()
|
||||
require.NoError(t, err, "failed to ping miniredis")
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return &MiniredisServer{
|
||||
server: mr,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAddr returns the address of the miniredis server
|
||||
func (m *MiniredisServer) GetAddr() string {
|
||||
return m.server.Addr()
|
||||
}
|
||||
|
||||
// GetClient returns the Redis client
|
||||
func (m *MiniredisServer) GetClient() *redis.Client {
|
||||
return m.client
|
||||
}
|
||||
|
||||
// FastForward advances the miniredis server's time
|
||||
func (m *MiniredisServer) FastForward(d time.Duration) {
|
||||
m.server.FastForward(d)
|
||||
}
|
||||
|
||||
// FlushAll removes all keys from the database
|
||||
func (m *MiniredisServer) FlushAll() {
|
||||
m.server.FlushAll()
|
||||
}
|
||||
|
||||
// SetError simulates a Redis error
|
||||
func (m *MiniredisServer) SetError(err string) {
|
||||
m.server.SetError(err)
|
||||
}
|
||||
|
||||
// ClearError clears any simulated errors
|
||||
func (m *MiniredisServer) ClearError() {
|
||||
m.server.SetError("")
|
||||
}
|
||||
|
||||
// CheckKeys verifies that specific keys exist in Redis
|
||||
func (m *MiniredisServer) CheckKeys() []string {
|
||||
return m.server.Keys()
|
||||
}
|
||||
|
||||
// TestConfig provides default test configuration
|
||||
type TestConfig struct {
|
||||
MaxSize int
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultTestConfig returns a standard test configuration
|
||||
func DefaultTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
CleanupInterval: 1 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTestData creates test cache data
|
||||
func GenerateTestData(count int) map[string][]byte {
|
||||
data := make(map[string][]byte, count)
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("test-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("test-value-%d", i))
|
||||
data[key] = value
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// GenerateLargeValue creates a large test value
|
||||
func GenerateLargeValue(sizeBytes int) []byte {
|
||||
return make([]byte, sizeBytes)
|
||||
}
|
||||
|
||||
// AssertCacheStats is a helper to verify cache statistics
|
||||
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
|
||||
t.Helper()
|
||||
|
||||
hits, ok := stats["hits"].(int64)
|
||||
require.True(t, ok, "hits should be int64")
|
||||
require.Equal(t, expectedHits, hits, "unexpected hit count")
|
||||
|
||||
misses, ok := stats["misses"].(int64)
|
||||
require.True(t, ok, "misses should be int64")
|
||||
require.Equal(t, expectedMisses, misses, "unexpected miss count")
|
||||
}
|
||||
|
||||
// WaitForCondition waits for a condition to be true or times out
|
||||
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(checkInterval)
|
||||
}
|
||||
t.Fatal("timeout waiting for condition")
|
||||
}
|
||||
|
||||
// AssertEventuallyExpires verifies that a key eventually expires
|
||||
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
|
||||
_, _, exists, err := backend.Get(ctx, key)
|
||||
return err == nil && !exists
|
||||
})
|
||||
}
|
||||
+329
@@ -0,0 +1,329 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrCircuitOpen is returned when the circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTooManyRequests is returned when too many requests are made in half-open state
|
||||
ErrTooManyRequests = errors.New("too many requests in half-open state")
|
||||
)
|
||||
|
||||
// State represents the state of the circuit breaker
|
||||
type State int32
|
||||
|
||||
const (
|
||||
// StateClosed allows all operations to pass through
|
||||
StateClosed State = iota
|
||||
|
||||
// StateOpen blocks all operations
|
||||
StateOpen
|
||||
|
||||
// StateHalfOpen allows a limited number of operations to test recovery
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
// String returns the string representation of the state
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
return &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
FailureThreshold: 0.6,
|
||||
Timeout: 30 * time.Second,
|
||||
HalfOpenMaxRequests: 3,
|
||||
ResetTimeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
config: config,
|
||||
lastStateChange: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a function through the circuit breaker
|
||||
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
|
||||
if !cb.AllowRequest() {
|
||||
cb.rejectedRequests.Add(1)
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
cb.totalRequests.Add(1)
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.RecordFailure()
|
||||
} else {
|
||||
cb.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AllowRequest checks if a request is allowed to proceed
|
||||
func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// Check if timeout has passed and we should try half-open
|
||||
cb.timeMu.RLock()
|
||||
shouldRetry := time.Now().After(cb.nextRetryTime)
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
if shouldRetry {
|
||||
cb.setState(StateHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.timeMu.Lock()
|
||||
cb.lastSuccessTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Reset consecutive failures
|
||||
cb.consecutiveFailures.Store(0)
|
||||
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.totalFailures.Add(1)
|
||||
failures := cb.consecutiveFailures.Add(1)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
// Check failure rate
|
||||
total := cb.totalRequests.Load()
|
||||
failureCount := cb.totalFailures.Load()
|
||||
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
// Any failure in half-open state reopens the circuit
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
// openCircuit transitions to open state
|
||||
func (cb *CircuitBreaker) openCircuit() {
|
||||
cb.setState(StateOpen)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
|
||||
cb.timeMu.Unlock()
|
||||
}
|
||||
|
||||
// GetState returns the current state
|
||||
func (cb *CircuitBreaker) GetState() State {
|
||||
return State(cb.state.Load())
|
||||
}
|
||||
|
||||
// setState changes the circuit breaker state
|
||||
func (cb *CircuitBreaker) setState(newState State) {
|
||||
oldState := State(cb.state.Swap(int32(newState)))
|
||||
|
||||
if oldState != newState {
|
||||
cb.stateTransitions.Add(1)
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMu.Unlock()
|
||||
|
||||
if cb.config.OnStateChange != nil {
|
||||
cb.config.OnStateChange(oldState, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.totalRequests.Store(0)
|
||||
cb.totalFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
cb.rejectedRequests.Store(0)
|
||||
cb.stateTransitions.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = now
|
||||
cb.lastSuccessTime = now
|
||||
cb.nextRetryTime = now
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = now
|
||||
cb.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// Stats returns circuit breaker statistics
|
||||
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
cb.timeMu.RLock()
|
||||
lastFailure := cb.lastFailureTime
|
||||
lastSuccess := cb.lastSuccessTime
|
||||
nextRetry := cb.nextRetryTime
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
cb.stateMu.RLock()
|
||||
lastChange := cb.lastStateChange
|
||||
cb.stateMu.RUnlock()
|
||||
|
||||
totalReq := cb.totalRequests.Load()
|
||||
totalFail := cb.totalFailures.Load()
|
||||
successRate := float64(0)
|
||||
if totalReq > 0 {
|
||||
successRate = float64(totalReq-totalFail) / float64(totalReq)
|
||||
}
|
||||
|
||||
return CircuitBreakerStats{
|
||||
State: cb.GetState(),
|
||||
ConsecutiveFailures: cb.consecutiveFailures.Load(),
|
||||
TotalRequests: totalReq,
|
||||
TotalFailures: totalFail,
|
||||
SuccessRate: successRate,
|
||||
RejectedRequests: cb.rejectedRequests.Load(),
|
||||
StateTransitions: cb.stateTransitions.Load(),
|
||||
LastFailureTime: lastFailure,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastStateChange: lastChange,
|
||||
NextRetryTime: nextRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
func (cb *CircuitBreaker) IsHealthy() bool {
|
||||
return cb.GetState() != StateOpen
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
|
||||
type CircuitBreakerBackend struct {
|
||||
backend backends.CacheBackend
|
||||
cb *CircuitBreaker
|
||||
}
|
||||
|
||||
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
|
||||
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreakerBackend{
|
||||
backend: b,
|
||||
cb: NewCircuitBreaker(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Set(ctx, key, value, ttl)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return nil, 0, false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
value, ttl, exists, err := c.backend.Get(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
deleted, err := c.backend.Delete(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
exists, err := c.backend.Exists(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Clear(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including circuit breaker state
|
||||
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
|
||||
stats := c.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
cbStats := c.cb.Stats()
|
||||
stats["circuit_breaker"] = map[string]interface{}{
|
||||
"state": cbStats.State.String(),
|
||||
"consecutive_failures": cbStats.ConsecutiveFailures,
|
||||
"total_requests": cbStats.TotalRequests,
|
||||
"total_failures": cbStats.TotalFailures,
|
||||
"success_rate": cbStats.SuccessRate,
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Ping(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the backend
|
||||
func (c *CircuitBreakerBackend) Close() error {
|
||||
return c.backend.Close()
|
||||
}
|
||||
+553
@@ -0,0 +1,553 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreaker_StateTransitions tests state machine transitions
|
||||
func TestCircuitBreaker_StateTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Initial state is closed", func(t *testing.T) {
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Closed to Open after max failures", func(t *testing.T) {
|
||||
cb.Reset()
|
||||
|
||||
// Simulate failures
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
|
||||
// Open the circuit
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should allow request and transition to half-open
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First request transitions to half-open and succeeds
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should be in half-open after first request
|
||||
state := cb.GetState()
|
||||
assert.True(t, state == StateHalfOpen || state == StateClosed,
|
||||
"After first successful request, should be half-open or potentially closed")
|
||||
|
||||
if state == StateHalfOpen {
|
||||
// Need more successful requests to close
|
||||
// The exact number depends on implementation but should be within HalfOpenMaxRequests
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// After multiple successful requests, should eventually close
|
||||
finalState := cb.GetState()
|
||||
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
|
||||
"After successful requests, circuit should transition towards closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First call transitions to half-open, second failure reopens
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
|
||||
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Requests should be blocked
|
||||
err := cb.Execute(ctx, func() error {
|
||||
t.Fatal("Should not execute function when circuit is open")
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
|
||||
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit then wait for half-open
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// After timeout, circuit should allow transition to half-open
|
||||
// Execute HalfOpenMaxRequests successful requests
|
||||
successCount := 0
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
err := cb.Execute(ctx, func() error {
|
||||
successCount++
|
||||
return nil
|
||||
})
|
||||
// Should allow up to HalfOpenMaxRequests
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify we executed the expected number
|
||||
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
|
||||
|
||||
// After successful requests, circuit behavior depends on implementation
|
||||
// It could close (allowing more requests) or stay half-open (blocking)
|
||||
// The important thing is that we allowed exactly HalfOpenMaxRequests
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Have some failures (but less than max)
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
|
||||
// One success should reset failures
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats = cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentAccess tests thread safety
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 5,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of successes and failures
|
||||
cb.Execute(ctx, func() error {
|
||||
if (id+j)%3 == 0 {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Random state checks
|
||||
_ = cb.GetState()
|
||||
_ = cb.Stats()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without panics
|
||||
stats := cb.Stats()
|
||||
assert.NotNil(t, stats)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Stats tests statistics tracking
|
||||
func TestCircuitBreaker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Execute some requests
|
||||
cb.Execute(ctx, func() error { return nil }) // Success
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
|
||||
stats := cb.Stats()
|
||||
|
||||
assert.Equal(t, StateClosed, stats.State)
|
||||
assert.Equal(t, int64(3), stats.TotalRequests)
|
||||
assert.Equal(t, int64(2), stats.TotalFailures)
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests circuit reset
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open the circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
assert.Equal(t, int64(0), stats.TotalFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateChangeCallback tests state change notifications
|
||||
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 50 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
OnStateChange: func(from, to State) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger state transitions
|
||||
// Closed -> Open
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
// Should be open now
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Open -> HalfOpen on first request after timeout
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Execute more successful requests to trigger HalfOpen -> Closed
|
||||
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.Contains(t, transitions, "closed->open")
|
||||
assert.Contains(t, transitions, "open->half-open")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsHealthy tests health check
|
||||
func TestCircuitBreaker_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, cb.IsHealthy())
|
||||
|
||||
// Open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
|
||||
|
||||
// Wait for timeout and allow successful request
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Should be healthy after recovery
|
||||
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
|
||||
func TestCircuitBreaker_RapidFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Rapid failures
|
||||
for i := 0; i < 10; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("rapid error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
stats := cb.Stats()
|
||||
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
|
||||
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
timeout := 100 * time.Millisecond
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: timeout,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait just before timeout
|
||||
time.Sleep(timeout - 20*time.Millisecond)
|
||||
assert.False(t, cb.IsHealthy())
|
||||
|
||||
// Wait until after timeout
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
// After timeout, AllowRequest should return true for transition to half-open
|
||||
assert.True(t, cb.AllowRequest())
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_DefaultConfig tests default configuration
|
||||
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := NewCircuitBreaker(nil) // Should use defaults
|
||||
|
||||
assert.NotNil(t, cb)
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
|
||||
// Verify defaults by triggering circuit breaker behavior
|
||||
ctx := context.Background()
|
||||
|
||||
// Test that it takes 5 failures to open (default MaxFailures)
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
|
||||
|
||||
// 5th failure should open it
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateString tests state string representation
|
||||
func TestCircuitBreaker_StateString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "closed", StateClosed.String())
|
||||
assert.Equal(t, "open", StateOpen.String())
|
||||
assert.Equal(t, "half-open", StateHalfOpen.String())
|
||||
assert.Equal(t, "unknown", State(999).String())
|
||||
}
|
||||
|
||||
// Benchmark circuit breaker performance
|
||||
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 100,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1000,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
if i%10 == 0 {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
+375
@@ -0,0 +1,375 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a backend
|
||||
type HealthStatus int32
|
||||
|
||||
const (
|
||||
// HealthUnknown indicates unknown health status
|
||||
HealthUnknown HealthStatus = iota
|
||||
|
||||
// HealthHealthy indicates the backend is healthy
|
||||
HealthHealthy
|
||||
|
||||
// HealthDegraded indicates the backend is degraded but operational
|
||||
HealthDegraded
|
||||
|
||||
// HealthUnhealthy indicates the backend is unhealthy
|
||||
HealthUnhealthy
|
||||
)
|
||||
|
||||
// String returns the string representation of the health status
|
||||
func (h HealthStatus) String() string {
|
||||
switch h {
|
||||
case HealthHealthy:
|
||||
return "healthy"
|
||||
case HealthDegraded:
|
||||
return "degraded"
|
||||
case HealthUnhealthy:
|
||||
return "unhealthy"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
return &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
HealthyThreshold: 3,
|
||||
UnhealthyThreshold: 3,
|
||||
DegradedThreshold: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
hc := &HealthChecker{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Start begins health checking
|
||||
func (hc *HealthChecker) Start() {
|
||||
if hc.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
hc.ticker = time.NewTicker(hc.config.CheckInterval)
|
||||
hc.wg.Add(1)
|
||||
go hc.checkLoop()
|
||||
}
|
||||
|
||||
// Stop stops health checking
|
||||
func (hc *HealthChecker) Stop() {
|
||||
if hc.stopped.Swap(true) {
|
||||
return // Already stopped
|
||||
}
|
||||
|
||||
close(hc.stopChan)
|
||||
if hc.ticker != nil {
|
||||
hc.ticker.Stop()
|
||||
}
|
||||
hc.wg.Wait()
|
||||
}
|
||||
|
||||
// checkLoop runs periodic health checks
|
||||
func (hc *HealthChecker) checkLoop() {
|
||||
defer hc.wg.Done()
|
||||
|
||||
// Initial check - log error but continue
|
||||
if err := hc.Check(context.Background()); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
case <-hc.ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
|
||||
if err := hc.Check(ctx); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check performs a health check
|
||||
func (hc *HealthChecker) Check(ctx context.Context) error {
|
||||
if hc.config.CheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hc.totalChecks.Add(1)
|
||||
start := time.Now()
|
||||
|
||||
// Create timeout context if not already set
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Perform health check
|
||||
err := hc.config.CheckFunc(ctx)
|
||||
latency := time.Since(start)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Update average latency
|
||||
hc.updateAverageLatency(latency)
|
||||
|
||||
if err != nil {
|
||||
hc.recordFailure()
|
||||
} else {
|
||||
hc.recordSuccess(latency)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
hc.totalSuccesses.Add(1)
|
||||
successes := hc.consecutiveSuccesses.Add(1)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastSuccessTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
currentStatus := hc.GetStatus()
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
} else {
|
||||
newStatus = HealthHealthy
|
||||
}
|
||||
}
|
||||
|
||||
if newStatus != currentStatus {
|
||||
hc.setStatus(newStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hc *HealthChecker) recordFailure() {
|
||||
hc.totalFailures.Add(1)
|
||||
failures := hc.consecutiveFailures.Add(1)
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastFailureTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
}
|
||||
|
||||
// updateAverageLatency updates the rolling average latency
|
||||
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
|
||||
// Simple exponential moving average
|
||||
currentAvg := time.Duration(hc.averageLatency.Load())
|
||||
if currentAvg == 0 {
|
||||
hc.averageLatency.Store(int64(latency))
|
||||
} else {
|
||||
// Weight: 0.2 for new value, 0.8 for old average
|
||||
newAvg := (currentAvg*4 + latency) / 5
|
||||
hc.averageLatency.Store(int64(newAvg))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the current health status
|
||||
func (hc *HealthChecker) GetStatus() HealthStatus {
|
||||
return HealthStatus(hc.status.Load())
|
||||
}
|
||||
|
||||
// setStatus changes the health status
|
||||
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
|
||||
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
|
||||
|
||||
if oldStatus != newStatus {
|
||||
hc.statusChanges.Add(1)
|
||||
if hc.config.OnStatusChange != nil {
|
||||
hc.config.OnStatusChange(oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy or degraded
|
||||
func (hc *HealthChecker) IsHealthy() bool {
|
||||
status := hc.GetStatus()
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// LastCheckTime returns the time of the last health check
|
||||
func (hc *HealthChecker) LastCheckTime() time.Time {
|
||||
hc.timeMu.RLock()
|
||||
defer hc.timeMu.RUnlock()
|
||||
return hc.lastCheckTime
|
||||
}
|
||||
|
||||
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
|
||||
func (hc *HealthChecker) HealthScore() float64 {
|
||||
status := hc.GetStatus()
|
||||
switch status {
|
||||
case HealthHealthy:
|
||||
return 1.0
|
||||
case HealthDegraded:
|
||||
return 0.7
|
||||
case HealthUnhealthy:
|
||||
return 0.0
|
||||
default:
|
||||
return 0.5
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns health checker statistics
|
||||
func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
hc.timeMu.RLock()
|
||||
lastCheck := hc.lastCheckTime
|
||||
lastSuccess := hc.lastSuccessTime
|
||||
lastFailure := hc.lastFailureTime
|
||||
hc.timeMu.RUnlock()
|
||||
|
||||
totalChecks := hc.totalChecks.Load()
|
||||
totalSuccesses := hc.totalSuccesses.Load()
|
||||
totalFailures := hc.totalFailures.Load()
|
||||
|
||||
successRate := float64(0)
|
||||
if totalChecks > 0 {
|
||||
successRate = float64(totalSuccesses) / float64(totalChecks)
|
||||
}
|
||||
|
||||
return HealthCheckerStats{
|
||||
Status: hc.GetStatus(),
|
||||
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
|
||||
ConsecutiveFailures: hc.consecutiveFailures.Load(),
|
||||
TotalChecks: totalChecks,
|
||||
TotalSuccesses: totalSuccesses,
|
||||
TotalFailures: totalFailures,
|
||||
SuccessRate: successRate,
|
||||
AverageLatency: time.Duration(hc.averageLatency.Load()),
|
||||
StatusChanges: hc.statusChanges.Load(),
|
||||
LastCheckTime: lastCheck,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastFailureTime: lastFailure,
|
||||
HealthScore: hc.HealthScore(),
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
func (hc *HealthChecker) Reset() {
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
hc.totalChecks.Store(0)
|
||||
hc.totalSuccesses.Store(0)
|
||||
hc.totalFailures.Store(0)
|
||||
hc.statusChanges.Store(0)
|
||||
hc.averageLatency.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = now
|
||||
hc.lastSuccessTime = now
|
||||
hc.lastFailureTime = now
|
||||
hc.timeMu.Unlock()
|
||||
}
|
||||
+215
@@ -0,0 +1,215 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// HealthCheckBackend wraps a cache backend with health checking
|
||||
type HealthCheckBackend struct {
|
||||
backend backends.CacheBackend
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Health tracking
|
||||
status atomic.Int32
|
||||
consecutiveFails atomic.Int32
|
||||
consecutiveOK atomic.Int32
|
||||
lastCheck time.Time
|
||||
checkMutex sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthCheckBackend creates a new health check wrapped backend
|
||||
func NewHealthCheckBackend(b backends.CacheBackend, config *HealthCheckConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
hc := &HealthCheckBackend{
|
||||
backend: b,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set initial status to healthy (optimistic)
|
||||
hc.status.Store(int32(HealthHealthy))
|
||||
|
||||
// Start health check routine
|
||||
hc.wg.Add(1)
|
||||
go hc.healthCheckLoop()
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Set stores a value and tracks health
|
||||
func (h *HealthCheckBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Allow operations even if unhealthy (may recover)
|
||||
err := h.backend.Set(ctx, key, value, ttl)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value and tracks health
|
||||
func (h *HealthCheckBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
value, ttl, exists, err := h.backend.Get(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key and tracks health
|
||||
func (h *HealthCheckBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
deleted, err := h.backend.Delete(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists and tracks health
|
||||
func (h *HealthCheckBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
exists, err := h.backend.Exists(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys and tracks health
|
||||
func (h *HealthCheckBackend) Clear(ctx context.Context) error {
|
||||
err := h.backend.Clear(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including health status
|
||||
func (h *HealthCheckBackend) GetStats() map[string]interface{} {
|
||||
stats := h.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
h.checkMutex.RLock()
|
||||
lastCheck := h.lastCheck
|
||||
h.checkMutex.RUnlock()
|
||||
|
||||
status := HealthStatus(h.status.Load())
|
||||
stats["health"] = map[string]interface{}{
|
||||
"status": status.String(),
|
||||
"consecutive_fails": h.consecutiveFails.Load(),
|
||||
"consecutive_ok": h.consecutiveOK.Load(),
|
||||
"last_check": lastCheck.Format(time.RFC3339),
|
||||
"time_since_check": time.Since(lastCheck).Seconds(),
|
||||
"check_interval_sec": h.config.CheckInterval.Seconds(),
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health
|
||||
func (h *HealthCheckBackend) Ping(ctx context.Context) error {
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the health checker and backend
|
||||
func (h *HealthCheckBackend) Close() error {
|
||||
// Stop health check routine
|
||||
h.cancel()
|
||||
|
||||
// Wait for routine to finish
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Finished normally
|
||||
case <-time.After(2 * time.Second):
|
||||
// Timeout
|
||||
}
|
||||
|
||||
return h.backend.Close()
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
status := HealthStatus(h.status.Load())
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
// Check if we should transition to healthy
|
||||
if fails > 0 && oks >= int32(h.config.HealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthHealthy)))
|
||||
if oldStatus != HealthHealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthHealthy)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
oks := h.consecutiveOK.Swap(0)
|
||||
fails := h.consecutiveFails.Add(1)
|
||||
|
||||
// Check if we should transition to unhealthy
|
||||
if oks > 0 && fails >= int32(h.config.UnhealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthUnhealthy)))
|
||||
if oldStatus != HealthUnhealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthUnhealthy)
|
||||
}
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold)*2 {
|
||||
// Severely degraded
|
||||
h.status.Store(int32(HealthUnhealthy))
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold) {
|
||||
// Degraded but still trying
|
||||
h.status.Store(int32(HealthDegraded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheckLoop runs periodic health checks
|
||||
func (h *HealthCheckBackend) healthCheckLoop() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(h.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check
|
||||
h.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a single health check
|
||||
func (h *HealthCheckBackend) performHealthCheck() {
|
||||
h.checkMutex.Lock()
|
||||
h.lastCheck = time.Now()
|
||||
h.checkMutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
}
|
||||
+447
@@ -0,0 +1,447 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestHealthChecker_StatusTransitions tests health status transitions
|
||||
func TestHealthChecker_StatusTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Initially unknown
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy after threshold failures
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status == HealthUnhealthy || status == HealthDegraded)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should recover towards healthy
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded || finalStatus == HealthUnknown)
|
||||
}
|
||||
|
||||
// TestHealthChecker_InitialState tests initial health status
|
||||
func TestHealthChecker_InitialState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
assert.False(t, hc.IsHealthy())
|
||||
}
|
||||
|
||||
// TestHealthChecker_ForceCheck tests manual health check trigger
|
||||
func TestHealthChecker_ForceCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Second, // Long interval
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
initialCount := callCount.Load()
|
||||
|
||||
// Force check
|
||||
hc.Check(context.Background())
|
||||
|
||||
// Should have been called
|
||||
assert.Greater(t, callCount.Load(), initialCount)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusChangeCallback tests status change notifications
|
||||
func TestHealthChecker_StatusChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should have status transitions
|
||||
assert.NotEmpty(t, transitions)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Stats tests statistics tracking
|
||||
func TestHealthChecker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if callCount.Load()%2 == 0 {
|
||||
return errors.New("failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 5,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
stats := hc.Stats()
|
||||
|
||||
assert.Greater(t, stats.TotalChecks, int64(0))
|
||||
assert.Greater(t, stats.TotalFailures, int64(0))
|
||||
assert.Greater(t, stats.SuccessRate, 0.0)
|
||||
assert.Less(t, stats.SuccessRate, 1.0)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Timeout tests check timeout handling
|
||||
func TestHealthChecker_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
// Simulate slow check
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy due to timeouts
|
||||
status := hc.GetStatus()
|
||||
assert.NotEqual(t, HealthHealthy, status)
|
||||
}
|
||||
|
||||
// TestHealthChecker_ConcurrentAccess tests thread safety
|
||||
func TestHealthChecker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 50; j++ {
|
||||
_ = hc.GetStatus()
|
||||
_ = hc.IsHealthy()
|
||||
_ = hc.Stats()
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// Should complete without panics
|
||||
}
|
||||
|
||||
// TestHealthChecker_StopAndStart tests lifecycle management
|
||||
func TestHealthChecker_StopAndStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
// Start
|
||||
hc.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count1 := callCount.Load()
|
||||
assert.Greater(t, count1, int32(0))
|
||||
|
||||
// Stop
|
||||
hc.Stop()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count2 := callCount.Load()
|
||||
|
||||
// Should not have increased significantly after stop
|
||||
assert.Less(t, count2-count1, int32(3))
|
||||
}
|
||||
|
||||
// TestHealthChecker_DegradedState tests degraded status
|
||||
func TestHealthChecker_DegradedState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
count := callCount.Add(1)
|
||||
// Fail once, then succeed
|
||||
if count == 1 {
|
||||
return errors.New("single failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3, // Need 3 failures for unhealthy
|
||||
HealthyThreshold: 2, // Need 2 successes for healthy
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// After initial checks, status should be set (might be healthy or degraded based on execution)
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status != HealthUnknown, "Status should not be unknown after checks")
|
||||
}
|
||||
|
||||
// TestHealthChecker_DefaultConfig tests default configuration
|
||||
func TestHealthChecker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
assert.NotNil(t, hc)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Verify default config was applied (we can't access private fields, so just check it works)
|
||||
assert.NotNil(t, hc)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusString tests status string representation
|
||||
func TestHealthChecker_StatusString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "healthy", HealthHealthy.String())
|
||||
assert.Equal(t, "unhealthy", HealthUnhealthy.String())
|
||||
assert.Equal(t, "degraded", HealthDegraded.String())
|
||||
assert.Equal(t, "unknown", HealthStatus(999).String())
|
||||
}
|
||||
|
||||
// TestHealthChecker_RecoveryPattern tests typical failure and recovery
|
||||
func TestHealthChecker_RecoveryPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var checkNumber atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
n := checkNumber.Add(1)
|
||||
// Fail checks 3-5, succeed others
|
||||
if n >= 3 && n <= 5 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var statusLog []HealthStatus
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
statusLog = append(statusLog, to)
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should see transitions through unhealthy and back to healthy
|
||||
assert.NotEmpty(t, statusLog)
|
||||
|
||||
// Final status should be healthy or degraded (recovered)
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded, "Should have recovered")
|
||||
}
|
||||
|
||||
// Benchmark health checker performance
|
||||
func BenchmarkHealthChecker_ForceCheck(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Minute,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHealthChecker_Status(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = hc.GetStatus()
|
||||
}
|
||||
}
|
||||
@@ -39,25 +39,25 @@ func (p *Auth0Provider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if strings.ToLower(scope) != "offline_access" {
|
||||
if strings.ToLower(scope) != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
@@ -48,18 +48,18 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range filteredScopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
filteredScopes = append(filteredScopes, "openid")
|
||||
filteredScopes = append(filteredScopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
// Default Cognito scopes if none specified
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
|
||||
filteredScopes = append(filteredScopes, "email", "profile")
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
|
||||
filteredScopes = append(filteredScopes, ScopeEmail, ScopeProfile)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -38,13 +38,13 @@ func (p *AzureProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -102,17 +102,17 @@ func (p *BaseProvider) ValidateTokenExpiry(session Session, token string, tokenC
|
||||
}
|
||||
|
||||
// BuildAuthParams constructs authorization parameters for the provider.
|
||||
// It includes the "offline_access" scope by default for refresh token support.
|
||||
// It includes the offline_access scope by default for refresh token support.
|
||||
func (p *BaseProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -38,7 +38,7 @@ func (p *GitHubProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// GitHub doesn't use offline_access scope, so remove it if present
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Remove offline_access scope as GitLab doesn't use it
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
@@ -47,18 +47,18 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Ensure openid scope is present for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range filteredScopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
filteredScopes = append(filteredScopes, "openid")
|
||||
filteredScopes = append(filteredScopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
// Default GitLab scopes if none specified
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
|
||||
filteredScopes = append(filteredScopes, "profile", "email")
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
|
||||
filteredScopes = append(filteredScopes, ScopeProfile, ScopeEmail)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -36,10 +36,10 @@ func (p *GoogleProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
baseParams.Set("access_type", "offline")
|
||||
baseParams.Set("prompt", "consent")
|
||||
|
||||
// Google does not use the "offline_access" scope, so we remove it if present.
|
||||
// Google does not use the ScopeOfflineAccess scope, so we remove it if present.
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,14 @@ const (
|
||||
ProviderTypeGitLab
|
||||
)
|
||||
|
||||
// Standard OAuth2/OIDC scope constants
|
||||
const (
|
||||
ScopeOfflineAccess = "offline_access"
|
||||
ScopeOpenID = "openid"
|
||||
ScopeProfile = "profile"
|
||||
ScopeEmail = "email"
|
||||
)
|
||||
|
||||
// ProviderCapabilities defines the specific features and behaviors of an OIDC provider.
|
||||
type ProviderCapabilities struct {
|
||||
PreferredTokenValidation string
|
||||
|
||||
@@ -39,25 +39,25 @@ func (p *KeycloakProvider) BuildAuthParams(baseParams url.Values, scopes []strin
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -39,25 +39,25 @@ func (p *OktaProvider) BuildAuthParams(baseParams url.Values, scopes []string) (
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -61,7 +61,7 @@ func (v *ConfigValidator) ValidateScopes(scopes []string) error {
|
||||
|
||||
hasOpenIDScope := false
|
||||
for _, scope := range scopes {
|
||||
if strings.TrimSpace(scope) == "openid" {
|
||||
if strings.TrimSpace(scope) == ScopeOpenID {
|
||||
hasOpenIDScope = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
httpClient = CreateDefaultHTTPClient()
|
||||
}
|
||||
goroutineWG := &sync.WaitGroup{}
|
||||
cacheManager := GetGlobalCacheManager(goroutineWG)
|
||||
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
|
||||
|
||||
// Use provided context instead of creating new one
|
||||
var pluginCtx context.Context
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisIntegration_MultipleInstances tests cache sharing across multiple instances
|
||||
func TestRedisIntegration_MultipleInstances(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Start miniredis server
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two backend instances sharing the same Redis
|
||||
config1 := backends.DefaultRedisConfig(mr.Addr())
|
||||
config1.RedisPrefix = "shared:"
|
||||
backend1, err := backends.NewRedisBackend(config1)
|
||||
require.NoError(t, err)
|
||||
defer backend1.Close()
|
||||
|
||||
config2 := backends.DefaultRedisConfig(mr.Addr())
|
||||
config2.RedisPrefix = "shared:"
|
||||
backend2, err := backends.NewRedisBackend(config2)
|
||||
require.NoError(t, err)
|
||||
defer backend2.Close()
|
||||
|
||||
t.Run("ShareTokenBlacklist", func(t *testing.T) {
|
||||
// Instance 1 blacklists a JTI
|
||||
jti := "test-jti-12345"
|
||||
err := backend1.Set(ctx, "jti:"+jti, []byte("blacklisted"), 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 2 should see the blacklisted JTI
|
||||
_, _, exists, err := backend2.Get(ctx, "jti:"+jti)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "JTI should be visible across instances")
|
||||
})
|
||||
|
||||
t.Run("ShareTokenCache", func(t *testing.T) {
|
||||
// Instance 1 caches a token
|
||||
token := "access-token-xyz"
|
||||
tokenData := []byte(`{"sub":"user123","exp":1234567890}`)
|
||||
err := backend1.Set(ctx, "token:"+token, tokenData, 5*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 2 retrieves the cached token
|
||||
retrieved, _, exists, err := backend2.Get(ctx, "token:"+token)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tokenData, retrieved)
|
||||
})
|
||||
|
||||
t.Run("ShareMetadataCache", func(t *testing.T) {
|
||||
// Instance 1 caches provider metadata
|
||||
metadataKey := "metadata:provider123"
|
||||
metadata := []byte(`{"issuer":"https://example.com","jwks_uri":"https://example.com/jwks"}`)
|
||||
err := backend1.Set(ctx, metadataKey, metadata, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 2 retrieves the metadata
|
||||
retrieved, ttl, exists, err := backend2.Get(ctx, metadataKey)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, metadata, retrieved)
|
||||
assert.Greater(t, ttl, 50*time.Minute)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_JTIReplayDetection tests JTI replay detection across instances
|
||||
func TestRedisIntegration_JTIReplayDetection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Multiple Traefik instances
|
||||
instances := make([]*backends.RedisBackend, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
config.RedisPrefix = "jti:"
|
||||
instances[i], err = backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer instances[i].Close()
|
||||
}
|
||||
|
||||
t.Run("PreventReplayAcrossInstances", func(t *testing.T) {
|
||||
jti := "replay-test-jti"
|
||||
|
||||
// First instance processes token and blacklists JTI
|
||||
err := instances[0].Set(ctx, jti, []byte("used"), 24*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other instances should detect the used JTI
|
||||
for i := 1; i < 3; i++ {
|
||||
exists, err := instances[i].Exists(ctx, jti)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Instance %d should see blacklisted JTI", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ConcurrentJTIChecks", func(t *testing.T) {
|
||||
jtiBase := "concurrent-jti"
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Simulate concurrent token processing across instances
|
||||
for instanceID := 0; instanceID < 3; instanceID++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
jti := fmt.Sprintf("%s-%d-%d", jtiBase, id, j)
|
||||
|
||||
// Check if JTI exists
|
||||
exists, _ := instances[id].Exists(ctx, jti)
|
||||
if !exists {
|
||||
// Mark as used
|
||||
instances[id].Set(ctx, jti, []byte("used"), 1*time.Hour)
|
||||
}
|
||||
}
|
||||
}(instanceID)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all JTIs were recorded
|
||||
for instanceID := 0; instanceID < 3; instanceID++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
jti := fmt.Sprintf("%s-%d-%d", jtiBase, instanceID, j)
|
||||
exists, err := instances[0].Exists(ctx, jti)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "JTI %s should exist", jti)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_Failover tests failover scenarios
|
||||
func TestRedisIntegration_Failover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisBackend, err := backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer redisBackend.Close()
|
||||
|
||||
t.Run("RedisTemporaryFailure", func(t *testing.T) {
|
||||
// Set some data
|
||||
key := "failover-key"
|
||||
value := []byte("failover-value")
|
||||
err := redisBackend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated connection error")
|
||||
|
||||
// Operations should fail gracefully
|
||||
_, _, exists, err := redisBackend.Get(ctx, key)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
// Clear error
|
||||
mr.SetError("")
|
||||
|
||||
// Operations should work again
|
||||
retrieved, _, exists, err := redisBackend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_HighLoad tests high load scenarios
|
||||
func TestRedisIntegration_HighLoad(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping high load test in short mode")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
config.PoolSize = 20
|
||||
redisBackend, err := backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer redisBackend.Close()
|
||||
|
||||
t.Run("HighConcurrency", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 50
|
||||
operations := 100
|
||||
|
||||
errors := make(chan error, goroutines*operations)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < operations; j++ {
|
||||
key := fmt.Sprintf("high-load-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("high-load-value-%d-%d", id, j))
|
||||
|
||||
// Write
|
||||
if err := redisBackend.Set(ctx, key, value, 1*time.Minute); err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Read
|
||||
retrieved, _, exists, err := redisBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
if !exists {
|
||||
errors <- fmt.Errorf("key %s does not exist", key)
|
||||
continue
|
||||
}
|
||||
if string(retrieved) != string(value) {
|
||||
errors <- fmt.Errorf("value mismatch for key %s", key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Operation error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
// Allow small error rate (< 1%)
|
||||
totalOps := goroutines * operations
|
||||
errorRate := float64(errorCount) / float64(totalOps)
|
||||
assert.Less(t, errorRate, 0.01, "Error rate should be less than 1%%")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_TTLConsistency tests TTL consistency across operations
|
||||
func TestRedisIntegration_TTLConsistency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisBackend, err := backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer redisBackend.Close()
|
||||
|
||||
t.Run("TTLAccuracy", func(t *testing.T) {
|
||||
key := "ttl-test-key"
|
||||
value := []byte("ttl-test-value")
|
||||
ttl := 5 * time.Second
|
||||
|
||||
err := redisBackend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := redisBackend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Greater(t, ttl1, 4*time.Second)
|
||||
assert.LessOrEqual(t, ttl1, ttl)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL again
|
||||
_, ttl2, exists, err := redisBackend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
assert.Greater(t, ttl2, 2*time.Second)
|
||||
|
||||
// Fast forward past expiration
|
||||
mr.FastForward(4 * time.Second)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = redisBackend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_MemoryUsage tests memory efficiency
|
||||
func TestRedisIntegration_MemoryUsage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory usage test in short mode")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
redisBackend, err := backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer redisBackend.Close()
|
||||
|
||||
t.Run("LargeDataset", func(t *testing.T) {
|
||||
// Store 10,000 items
|
||||
itemCount := 10000
|
||||
for i := 0; i < itemCount; i++ {
|
||||
key := fmt.Sprintf("memory-test-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("memory-test-value-%d-with-some-padding-to-make-it-larger", i))
|
||||
err := redisBackend.Set(ctx, key, value, 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Log progress
|
||||
if i%1000 == 0 {
|
||||
t.Logf("Stored %d items", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all items exist
|
||||
for i := 0; i < itemCount; i += 100 {
|
||||
key := fmt.Sprintf("memory-test-key-%d", i)
|
||||
exists, err := redisBackend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
// Check stats
|
||||
stats := redisBackend.GetStats()
|
||||
t.Logf("Redis backend stats: %+v", stats)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisIntegration_Cleanup tests cache cleanup functionality
|
||||
func TestRedisIntegration_Cleanup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err)
|
||||
defer mr.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
config := backends.DefaultRedisConfig(mr.Addr())
|
||||
config.RedisPrefix = "cleanup-test:"
|
||||
redisBackend, err := backends.NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer redisBackend.Close()
|
||||
|
||||
t.Run("BulkCleanup", func(t *testing.T) {
|
||||
// Add many items
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := redisBackend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
err := redisBackend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items are gone
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := redisBackend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
})
|
||||
}
|
||||
+17
-49
@@ -971,7 +971,9 @@ func (cm *ChunkManager) validateTokenExpiration(token string, config TokenConfig
|
||||
// Returns:
|
||||
// - The expiration time if present, nil if no 'exp' claim.
|
||||
// - An error if JWT parsing fails.
|
||||
func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
|
||||
//
|
||||
// extractJWTClaim extracts a time claim from a JWT token
|
||||
func (cm *ChunkManager) extractJWTClaim(token, claimName string) (*time.Time, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
@@ -992,25 +994,29 @@ func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
exp, exists := claims["exp"]
|
||||
claimValue, exists := claims[claimName]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert expiration to time.Time
|
||||
var expTime time.Time
|
||||
switch v := exp.(type) {
|
||||
// Convert claim to time.Time
|
||||
var claimTime time.Time
|
||||
switch v := claimValue.(type) {
|
||||
case float64:
|
||||
expTime = time.Unix(int64(v), 0)
|
||||
claimTime = time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
expTime = time.Unix(v, 0)
|
||||
claimTime = time.Unix(v, 0)
|
||||
case int:
|
||||
expTime = time.Unix(int64(v), 0)
|
||||
claimTime = time.Unix(int64(v), 0)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid expiration format: %T", exp)
|
||||
return nil, fmt.Errorf("invalid %s format: %T", claimName, claimValue)
|
||||
}
|
||||
|
||||
return &expTime, nil
|
||||
return &claimTime, nil
|
||||
}
|
||||
|
||||
func (cm *ChunkManager) extractJWTExpiration(token string) (*time.Time, error) {
|
||||
return cm.extractJWTClaim(token, "exp")
|
||||
}
|
||||
|
||||
// validateTokenFreshness checks if token is fresh enough for storage.
|
||||
@@ -1062,45 +1068,7 @@ func (cm *ChunkManager) validateTokenFreshness(token string, config TokenConfig)
|
||||
// - The issued at time if present, nil if no 'iat' claim.
|
||||
// - An error if JWT parsing fails.
|
||||
func (cm *ChunkManager) extractJWTIssuedAt(token string) (*time.Time, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON payload using pooled decoder
|
||||
var claims map[string]interface{}
|
||||
pm := pool.Get()
|
||||
decoder := pm.GetJSONDecoder(bytes.NewReader(payload))
|
||||
defer pm.PutJSONDecoder(decoder)
|
||||
|
||||
if err := decoder.Decode(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
iat, exists := claims["iat"]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Convert issued at to time.Time
|
||||
var iatTime time.Time
|
||||
switch v := iat.(type) {
|
||||
case float64:
|
||||
iatTime = time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
iatTime = time.Unix(v, 0)
|
||||
case int:
|
||||
iatTime = time.Unix(int64(v), 0)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid issued at format: %T", iat)
|
||||
}
|
||||
|
||||
return &iatTime, nil
|
||||
return cm.extractJWTClaim(token, "iat")
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes expired sessions to prevent memory leaks.
|
||||
|
||||
+413
@@ -89,6 +89,73 @@ type Config struct {
|
||||
// Recommended: true for multi-replica deployments
|
||||
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
// Redis configures the Redis cache backend for distributed caching.
|
||||
// When enabled, provides cache sharing across multiple Traefik replicas.
|
||||
// Default: nil (disabled - uses in-memory caching)
|
||||
Redis *RedisConfig `json:"redis,omitempty"`
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
// All fields support both JSON and YAML configuration for compatibility with Traefik's
|
||||
// dynamic configuration (labels, YAML files, etc.)
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis caching should be used (default: false)
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
|
||||
// Address is the Redis server address (e.g., "localhost:6379", "redis:6379")
|
||||
Address string `json:"address" yaml:"address"`
|
||||
|
||||
// Password for Redis authentication (optional, leave empty for no auth)
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the Redis database number to use (default: 0)
|
||||
DB int `json:"db" yaml:"db"`
|
||||
|
||||
// KeyPrefix is the prefix for all Redis keys (default: "traefikoidc:")
|
||||
KeyPrefix string `json:"keyPrefix" yaml:"keyPrefix"`
|
||||
|
||||
// PoolSize is the maximum number of socket connections (default: 10)
|
||||
PoolSize int `json:"poolSize" yaml:"poolSize"`
|
||||
|
||||
// ConnectTimeout is the timeout for establishing connections in seconds (default: 5)
|
||||
ConnectTimeout int `json:"connectTimeout" yaml:"connectTimeout"`
|
||||
|
||||
// ReadTimeout is the timeout for read operations in seconds (default: 3)
|
||||
ReadTimeout int `json:"readTimeout" yaml:"readTimeout"`
|
||||
|
||||
// WriteTimeout is the timeout for write operations in seconds (default: 3)
|
||||
WriteTimeout int `json:"writeTimeout" yaml:"writeTimeout"`
|
||||
|
||||
// EnableTLS indicates if TLS should be used for Redis connections (default: false)
|
||||
EnableTLS bool `json:"enableTLS" yaml:"enableTLS"`
|
||||
|
||||
// TLSSkipVerify skips TLS certificate verification (not recommended for production)
|
||||
TLSSkipVerify bool `json:"tlsSkipVerify" yaml:"tlsSkipVerify"`
|
||||
|
||||
// CacheMode determines the caching strategy: "redis" (Redis only), "hybrid" (Memory+Redis), "memory" (Memory only)
|
||||
// Default: "redis" when enabled
|
||||
CacheMode string `json:"cacheMode" yaml:"cacheMode"`
|
||||
|
||||
// HybridL1Size is the maximum number of items in L1 cache for hybrid mode (default: 500)
|
||||
HybridL1Size int `json:"hybridL1Size" yaml:"hybridL1Size"`
|
||||
|
||||
// HybridL1MemoryMB is the maximum memory in MB for L1 cache in hybrid mode (default: 10)
|
||||
HybridL1MemoryMB int64 `json:"hybridL1MemoryMB" yaml:"hybridL1MemoryMB"`
|
||||
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis failures (default: true)
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker" yaml:"enableCircuitBreaker"`
|
||||
|
||||
// CircuitBreakerThreshold is the number of failures before opening circuit (default: 5)
|
||||
CircuitBreakerThreshold int `json:"circuitBreakerThreshold" yaml:"circuitBreakerThreshold"`
|
||||
|
||||
// CircuitBreakerTimeout is the timeout in seconds before attempting to close circuit (default: 60)
|
||||
CircuitBreakerTimeout int `json:"circuitBreakerTimeout" yaml:"circuitBreakerTimeout"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks for Redis (default: true)
|
||||
EnableHealthCheck bool `json:"enableHealthCheck" yaml:"enableHealthCheck"`
|
||||
|
||||
// HealthCheckInterval is the interval in seconds between health checks (default: 30)
|
||||
HealthCheckInterval int `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
@@ -167,11 +234,14 @@ const (
|
||||
// - PostLogoutRedirectURI: "/"
|
||||
// - ForceHTTPS: true (for security)
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
// - Redis: nil (disabled by default, can be configured via Traefik config or env vars)
|
||||
//
|
||||
// CreateConfig initializes a new Config struct with default values for optional fields.
|
||||
// It sets default scopes, log level, rate limit, enables ForceHTTPS, and sets the
|
||||
// default refresh grace period. Required fields like ProviderURL, ClientID, ClientSecret,
|
||||
// CallbackURL, and SessionEncryptionKey must be set explicitly after creation.
|
||||
// Redis configuration can be provided through Traefik's dynamic configuration or
|
||||
// as a fallback through environment variables.
|
||||
//
|
||||
// Returns:
|
||||
// - A pointer to a new Config struct with default settings applied.
|
||||
@@ -185,6 +255,7 @@ func CreateConfig() *Config {
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
|
||||
}
|
||||
|
||||
return c
|
||||
@@ -329,6 +400,13 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Redis configuration if provided
|
||||
if c.Redis != nil && c.Redis.Enabled {
|
||||
if err := c.Redis.Validate(); err != nil {
|
||||
return fmt.Errorf("redis configuration error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate headers configuration for template security
|
||||
for _, header := range c.Headers {
|
||||
if header.Name == "" {
|
||||
@@ -803,6 +881,341 @@ func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Req
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if an origin is in the allowed list
|
||||
// Validate checks if the Redis configuration is valid
|
||||
func (rc *RedisConfig) Validate() error {
|
||||
if !rc.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rc.Address == "" {
|
||||
return fmt.Errorf("redis address is required when Redis is enabled")
|
||||
}
|
||||
|
||||
// Validate cache mode
|
||||
if rc.CacheMode != "" {
|
||||
validModes := map[string]bool{
|
||||
"redis": true,
|
||||
"hybrid": true,
|
||||
"memory": true,
|
||||
}
|
||||
if !validModes[rc.CacheMode] {
|
||||
return fmt.Errorf("invalid cache mode: %s (must be 'redis', 'hybrid', or 'memory')", rc.CacheMode)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate connection settings
|
||||
if rc.PoolSize < 0 {
|
||||
return fmt.Errorf("pool size cannot be negative")
|
||||
}
|
||||
if rc.ConnectTimeout < 0 {
|
||||
return fmt.Errorf("connect timeout cannot be negative")
|
||||
}
|
||||
if rc.ReadTimeout < 0 {
|
||||
return fmt.Errorf("read timeout cannot be negative")
|
||||
}
|
||||
if rc.WriteTimeout < 0 {
|
||||
return fmt.Errorf("write timeout cannot be negative")
|
||||
}
|
||||
|
||||
// Validate hybrid mode settings
|
||||
if rc.CacheMode == "hybrid" {
|
||||
if rc.HybridL1Size < 0 {
|
||||
return fmt.Errorf("hybrid L1 size cannot be negative")
|
||||
}
|
||||
if rc.HybridL1MemoryMB < 0 {
|
||||
return fmt.Errorf("hybrid L1 memory cannot be negative")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate circuit breaker settings
|
||||
if rc.CircuitBreakerThreshold < 0 {
|
||||
return fmt.Errorf("circuit breaker threshold cannot be negative")
|
||||
}
|
||||
if rc.CircuitBreakerTimeout < 0 {
|
||||
return fmt.Errorf("circuit breaker timeout cannot be negative")
|
||||
}
|
||||
|
||||
// Validate health check settings
|
||||
if rc.HealthCheckInterval < 0 {
|
||||
return fmt.Errorf("health check interval cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDefaults sets default values for Redis configuration when fields are not explicitly set.
|
||||
// This ensures reasonable defaults while allowing full customization through configuration.
|
||||
func (rc *RedisConfig) ApplyDefaults() {
|
||||
// Only apply defaults if Redis is enabled
|
||||
if !rc.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
// Connection defaults
|
||||
if rc.KeyPrefix == "" {
|
||||
rc.KeyPrefix = "traefikoidc:"
|
||||
}
|
||||
if rc.PoolSize == 0 {
|
||||
rc.PoolSize = 10
|
||||
}
|
||||
if rc.ConnectTimeout == 0 {
|
||||
rc.ConnectTimeout = 5
|
||||
}
|
||||
if rc.ReadTimeout == 0 {
|
||||
rc.ReadTimeout = 3
|
||||
}
|
||||
if rc.WriteTimeout == 0 {
|
||||
rc.WriteTimeout = 3
|
||||
}
|
||||
|
||||
// Cache mode defaults
|
||||
if rc.CacheMode == "" {
|
||||
rc.CacheMode = "redis" // Default to redis-only mode for simplicity
|
||||
}
|
||||
|
||||
// Hybrid mode specific defaults
|
||||
if rc.CacheMode == "hybrid" {
|
||||
if rc.HybridL1Size == 0 {
|
||||
rc.HybridL1Size = 500
|
||||
}
|
||||
if rc.HybridL1MemoryMB == 0 {
|
||||
rc.HybridL1MemoryMB = 10
|
||||
}
|
||||
}
|
||||
|
||||
// Resilience features - these use a different pattern to detect if they were explicitly set
|
||||
// Since bool fields default to false, we need to be careful about defaults
|
||||
// For now, we'll enable by default only if not explicitly disabled via environment
|
||||
if rc.CircuitBreakerThreshold == 0 {
|
||||
rc.CircuitBreakerThreshold = 5
|
||||
}
|
||||
if rc.CircuitBreakerTimeout == 0 {
|
||||
rc.CircuitBreakerTimeout = 60
|
||||
}
|
||||
if rc.HealthCheckInterval == 0 {
|
||||
rc.HealthCheckInterval = 30
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyEnvFallbacks applies environment variable values as fallbacks for empty config fields.
|
||||
// This allows environment variables to be used as optional overrides only when the
|
||||
// corresponding config field is not set through Traefik's dynamic configuration.
|
||||
// The plugin configuration takes precedence over environment variables.
|
||||
func (rc *RedisConfig) ApplyEnvFallbacks() {
|
||||
// Only apply env fallbacks if Redis is not already configured
|
||||
if !rc.Enabled {
|
||||
// Check if Redis should be enabled from environment
|
||||
enabledStr := os.Getenv("REDIS_ENABLED")
|
||||
if enabledStr == "true" || enabledStr == "1" {
|
||||
rc.Enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// Only apply other env vars if Redis is enabled
|
||||
if !rc.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply environment variables only for empty fields
|
||||
if rc.Address == "" {
|
||||
if addr := os.Getenv("REDIS_ADDRESS"); addr != "" {
|
||||
rc.Address = addr
|
||||
}
|
||||
}
|
||||
|
||||
if rc.Password == "" {
|
||||
rc.Password = os.Getenv("REDIS_PASSWORD")
|
||||
}
|
||||
|
||||
if rc.KeyPrefix == "" {
|
||||
if prefix := os.Getenv("REDIS_KEY_PREFIX"); prefix != "" {
|
||||
rc.KeyPrefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
if rc.CacheMode == "" {
|
||||
if mode := os.Getenv("REDIS_CACHE_MODE"); mode != "" {
|
||||
rc.CacheMode = mode
|
||||
}
|
||||
}
|
||||
|
||||
// Apply numeric values only if not already set
|
||||
if rc.DB == 0 {
|
||||
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
|
||||
if db, err := strconv.Atoi(dbStr); err == nil && db > 0 {
|
||||
rc.DB = db
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rc.PoolSize == 0 {
|
||||
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
|
||||
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil && poolSize > 0 {
|
||||
rc.PoolSize = poolSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rc.ConnectTimeout == 0 {
|
||||
if timeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); timeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
|
||||
rc.ConnectTimeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rc.ReadTimeout == 0 {
|
||||
if timeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); timeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
|
||||
rc.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rc.WriteTimeout == 0 {
|
||||
if timeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); timeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
|
||||
rc.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply boolean values from env only if not already set in config
|
||||
if !rc.EnableTLS {
|
||||
if tlsStr := os.Getenv("REDIS_ENABLE_TLS"); tlsStr == "true" || tlsStr == "1" {
|
||||
rc.EnableTLS = true
|
||||
}
|
||||
}
|
||||
|
||||
if !rc.TLSSkipVerify {
|
||||
if skipStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipStr == "true" || skipStr == "1" {
|
||||
rc.TLSSkipVerify = true
|
||||
}
|
||||
}
|
||||
|
||||
// Hybrid mode settings
|
||||
if rc.HybridL1Size == 0 {
|
||||
if sizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); sizeStr != "" {
|
||||
if size, err := strconv.Atoi(sizeStr); err == nil && size > 0 {
|
||||
rc.HybridL1Size = size
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rc.HybridL1MemoryMB == 0 {
|
||||
if memStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); memStr != "" {
|
||||
if mem, err := strconv.ParseInt(memStr, 10, 64); err == nil && mem > 0 {
|
||||
rc.HybridL1MemoryMB = mem
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRedisConfigFromEnv loads Redis configuration from environment variables.
|
||||
// Deprecated: Use RedisConfig.ApplyEnvFallbacks() on an existing config instead.
|
||||
// This function is kept for backward compatibility but should not be used directly.
|
||||
func LoadRedisConfigFromEnv() *RedisConfig {
|
||||
// Check if Redis is enabled
|
||||
enabledStr := os.Getenv("REDIS_ENABLED")
|
||||
if enabledStr == "" || enabledStr == "false" || enabledStr == "0" {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &RedisConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
// Parse numeric values
|
||||
if dbStr := os.Getenv("REDIS_DB"); dbStr != "" {
|
||||
if db, err := strconv.Atoi(dbStr); err == nil {
|
||||
config.DB = db
|
||||
}
|
||||
}
|
||||
|
||||
if poolSizeStr := os.Getenv("REDIS_POOL_SIZE"); poolSizeStr != "" {
|
||||
if poolSize, err := strconv.Atoi(poolSizeStr); err == nil {
|
||||
config.PoolSize = poolSize
|
||||
}
|
||||
}
|
||||
|
||||
if connectTimeoutStr := os.Getenv("REDIS_CONNECT_TIMEOUT"); connectTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(connectTimeoutStr); err == nil {
|
||||
config.ConnectTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if readTimeoutStr := os.Getenv("REDIS_READ_TIMEOUT"); readTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(readTimeoutStr); err == nil {
|
||||
config.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
if writeTimeoutStr := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(writeTimeoutStr); err == nil {
|
||||
config.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse boolean values
|
||||
if enableTLSStr := os.Getenv("REDIS_ENABLE_TLS"); enableTLSStr == "true" || enableTLSStr == "1" {
|
||||
config.EnableTLS = true
|
||||
}
|
||||
|
||||
if skipVerifyStr := os.Getenv("REDIS_TLS_SKIP_VERIFY"); skipVerifyStr == "true" || skipVerifyStr == "1" {
|
||||
config.TLSSkipVerify = true
|
||||
}
|
||||
|
||||
// Parse hybrid mode settings
|
||||
if l1SizeStr := os.Getenv("REDIS_HYBRID_L1_SIZE"); l1SizeStr != "" {
|
||||
if size, err := strconv.Atoi(l1SizeStr); err == nil {
|
||||
config.HybridL1Size = size
|
||||
}
|
||||
}
|
||||
|
||||
if l1MemoryStr := os.Getenv("REDIS_HYBRID_L1_MEMORY_MB"); l1MemoryStr != "" {
|
||||
if memory, err := strconv.ParseInt(l1MemoryStr, 10, 64); err == nil {
|
||||
config.HybridL1MemoryMB = memory
|
||||
}
|
||||
}
|
||||
|
||||
// Parse circuit breaker settings
|
||||
if enableCBStr := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCBStr == "false" || enableCBStr == "0" {
|
||||
config.EnableCircuitBreaker = false
|
||||
} else {
|
||||
config.EnableCircuitBreaker = true // Default to enabled
|
||||
}
|
||||
|
||||
if cbThresholdStr := os.Getenv("REDIS_CIRCUIT_BREAKER_THRESHOLD"); cbThresholdStr != "" {
|
||||
if threshold, err := strconv.Atoi(cbThresholdStr); err == nil {
|
||||
config.CircuitBreakerThreshold = threshold
|
||||
}
|
||||
}
|
||||
|
||||
if cbTimeoutStr := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeoutStr != "" {
|
||||
if timeout, err := strconv.Atoi(cbTimeoutStr); err == nil {
|
||||
config.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Parse health check settings
|
||||
if enableHCStr := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHCStr == "false" || enableHCStr == "0" {
|
||||
config.EnableHealthCheck = false
|
||||
} else {
|
||||
config.EnableHealthCheck = true // Default to enabled
|
||||
}
|
||||
|
||||
if hcIntervalStr := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcIntervalStr != "" {
|
||||
if interval, err := strconv.Atoi(hcIntervalStr); err == nil {
|
||||
config.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// Apply defaults after loading from env
|
||||
config.ApplyDefaults()
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
|
||||
@@ -3,10 +3,13 @@ package traefikoidc
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// CacheType defines the type of cache for optimized behavior
|
||||
@@ -89,6 +92,9 @@ type UniversalCache struct {
|
||||
config UniversalCacheConfig
|
||||
logger *Logger
|
||||
|
||||
// Backend for distributed caching (NEW)
|
||||
backend backends.CacheBackend
|
||||
|
||||
// Memory management
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
@@ -110,6 +116,13 @@ func NewUniversalCache(config UniversalCacheConfig) *UniversalCache {
|
||||
return createUniversalCache(config)
|
||||
}
|
||||
|
||||
// NewUniversalCacheWithBackend creates a new universal cache with a specific backend
|
||||
func NewUniversalCacheWithBackend(config UniversalCacheConfig, cacheBackend backends.CacheBackend) *UniversalCache {
|
||||
cache := createUniversalCache(config)
|
||||
cache.backend = cacheBackend
|
||||
return cache
|
||||
}
|
||||
|
||||
// createUniversalCache is the internal constructor
|
||||
func createUniversalCache(config UniversalCacheConfig) *UniversalCache {
|
||||
// Apply type-specific defaults first (including MaxSize)
|
||||
@@ -223,6 +236,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
ttl = c.config.DefaultTTL
|
||||
}
|
||||
|
||||
// If we have a backend, use it for distributed caching
|
||||
if c.backend != nil {
|
||||
// Serialize the value
|
||||
data, err := c.serialize(value)
|
||||
if err != nil {
|
||||
c.logger.Errorf("Failed to serialize value for key %s: %v", key, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Store in backend
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := c.backend.Set(ctx, c.prefixKey(key), data, ttl); err != nil {
|
||||
c.logger.Infof("Backend set error for key %s: %v", key, err)
|
||||
// Continue with local cache even if backend fails
|
||||
}
|
||||
}
|
||||
|
||||
size := c.estimateSize(value)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -285,6 +317,32 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
// Try backend first if available (for distributed consistency)
|
||||
if c.backend != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
data, _, exists, err := c.backend.Get(ctx, c.prefixKey(key))
|
||||
if err != nil {
|
||||
c.logger.Debugf("Backend get error for key %s: %v", key, err)
|
||||
// Fall through to local cache
|
||||
} else if exists {
|
||||
// Deserialize the value
|
||||
var value interface{}
|
||||
if err := c.deserialize(data, &value); err != nil {
|
||||
c.logger.Errorf("Failed to deserialize value for key %s: %v", key, err)
|
||||
// Fall through to local cache
|
||||
} else {
|
||||
atomic.AddInt64(&c.hits, 1)
|
||||
// Update local cache with backend value
|
||||
go func() {
|
||||
_ = c.updateLocalCache(key, value, c.config.DefaultTTL)
|
||||
}()
|
||||
return value, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -341,6 +399,17 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (c *UniversalCache) Delete(key string) bool {
|
||||
// Delete from backend if available
|
||||
if c.backend != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if _, err := c.backend.Delete(ctx, c.prefixKey(key)); err != nil {
|
||||
c.logger.Debugf("Backend delete error for key %s: %v", key, err)
|
||||
// Continue with local delete
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -355,6 +424,17 @@ func (c *UniversalCache) Delete(key string) bool {
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (c *UniversalCache) Clear() {
|
||||
// Clear backend if available
|
||||
if c.backend != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := c.backend.Clear(ctx); err != nil {
|
||||
c.logger.Infof("Backend clear error: %v", err)
|
||||
// Continue with local clear
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -437,6 +517,13 @@ func (c *UniversalCache) Close() error {
|
||||
// Clear all items
|
||||
c.Clear()
|
||||
|
||||
// Close backend if present
|
||||
if c.backend != nil {
|
||||
if err := c.backend.Close(); err != nil {
|
||||
c.logger.Infof("Failed to close cache backend: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Debugf("UniversalCache[%s]: Closed", c.config.Type)
|
||||
return nil
|
||||
}
|
||||
@@ -701,3 +788,60 @@ func (c *UniversalCache) Mutex() *sync.RWMutex {
|
||||
func (c *UniversalCache) Strategy() CacheStrategy {
|
||||
return c.config.Strategy
|
||||
}
|
||||
|
||||
// serialize converts a value to bytes for backend storage
|
||||
func (c *UniversalCache) serialize(value interface{}) ([]byte, error) {
|
||||
// Use JSON for serialization - simple and universal
|
||||
return json.Marshal(value)
|
||||
}
|
||||
|
||||
// deserialize converts bytes from backend storage to a value
|
||||
func (c *UniversalCache) deserialize(data []byte, value interface{}) error {
|
||||
// Use JSON for deserialization
|
||||
return json.Unmarshal(data, value)
|
||||
}
|
||||
|
||||
// prefixKey adds a cache type prefix to the key for backend storage
|
||||
func (c *UniversalCache) prefixKey(key string) string {
|
||||
return fmt.Sprintf("%s:%s", c.config.Type, key)
|
||||
}
|
||||
|
||||
// updateLocalCache updates the local cache with a value from the backend
|
||||
func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl time.Duration) error {
|
||||
size := c.estimateSize(value)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Check memory limits
|
||||
if c.config.MaxMemoryBytes > 0 {
|
||||
for c.currentMemory+size > c.config.MaxMemoryBytes && c.lruList.Len() > 0 {
|
||||
c.evictOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// Check size limits
|
||||
if c.lruList.Len() >= c.config.MaxSize {
|
||||
c.evictOldest()
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
item := &CacheItem{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Size: size,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
LastAccessed: now,
|
||||
AccessCount: 1,
|
||||
CacheType: c.config.Type,
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
item.element = c.lruList.PushFront(key)
|
||||
c.items[key] = item
|
||||
|
||||
c.currentSize++
|
||||
c.currentMemory += size
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+270
-39
@@ -3,6 +3,9 @@ package traefikoidc
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/resilience"
|
||||
)
|
||||
|
||||
// UniversalCacheManager manages all cache instances using the universal cache
|
||||
@@ -34,25 +37,217 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager {
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// Initialize token cache - CRITICAL FIX: Reduced from 5000 to 1000
|
||||
universalCacheManager.tokenCache = NewUniversalCache(UniversalCacheConfig{
|
||||
// Initialize with default in-memory backends
|
||||
initializeDefaultCaches(universalCacheManager, logger)
|
||||
})
|
||||
|
||||
return universalCacheManager
|
||||
}
|
||||
|
||||
// GetUniversalCacheManagerWithConfig returns the singleton universal cache manager with Redis configuration
|
||||
func GetUniversalCacheManagerWithConfig(logger *Logger, redisConfig *RedisConfig) *UniversalCacheManager {
|
||||
universalCacheManagerOnce.Do(func() {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
universalCacheManager = &UniversalCacheManager{
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if redisConfig != nil && redisConfig.Enabled {
|
||||
logger.Infof("Initializing cache manager with Redis backend: %s", redisConfig.Address)
|
||||
initializeCachesWithRedis(universalCacheManager, logger, redisConfig)
|
||||
} else {
|
||||
logger.Info("Initializing cache manager with memory-only backend")
|
||||
initializeDefaultCaches(universalCacheManager, logger)
|
||||
}
|
||||
})
|
||||
|
||||
return universalCacheManager
|
||||
}
|
||||
|
||||
// initializeDefaultCaches initializes caches with memory-only backends
|
||||
func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
// Initialize token cache - CRITICAL FIX: Reduced from 5000 to 1000
|
||||
manager.tokenCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items
|
||||
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize blacklist cache
|
||||
manager.blacklistCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize metadata cache with grace periods
|
||||
manager.metadataCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
MetadataConfig: &MetadataCacheConfig{
|
||||
GracePeriod: 5 * time.Minute,
|
||||
ExtendedGracePeriod: 15 * time.Minute,
|
||||
MaxGracePeriod: 30 * time.Minute,
|
||||
SecurityCriticalMaxGracePeriod: 15 * time.Minute,
|
||||
SecurityCriticalFields: []string{
|
||||
"jwks_uri",
|
||||
"token_endpoint",
|
||||
"authorization_endpoint",
|
||||
"issuer",
|
||||
},
|
||||
},
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize JWK cache
|
||||
manager.jwkCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 200,
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize session cache - CRITICAL FIX: Reduced from 10000 to 2000
|
||||
manager.sessionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items
|
||||
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
|
||||
DefaultTTL: 30 * time.Minute,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662)
|
||||
manager.introspectionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for introspection results
|
||||
MaxSize: 1000, // Cache up to 1000 introspection results
|
||||
DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently)
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize token type cache for performance optimization
|
||||
manager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for token type detection
|
||||
MaxSize: 2000, // Cache up to 2000 token type detections
|
||||
DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection
|
||||
Logger: logger,
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, redisConfig *RedisConfig) {
|
||||
// Apply defaults to Redis config
|
||||
redisConfig.ApplyDefaults()
|
||||
|
||||
// Create Redis backend
|
||||
redisBackendConfig := &backends.Config{
|
||||
Type: backends.BackendTypeRedis,
|
||||
RedisAddr: redisConfig.Address,
|
||||
RedisPassword: redisConfig.Password,
|
||||
RedisDB: redisConfig.DB,
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
|
||||
var redisBackend backends.CacheBackend
|
||||
var err error
|
||||
|
||||
// Create Redis backend with resilience features if enabled
|
||||
redisBackend, err = backends.NewRedisBackend(redisBackendConfig)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create Redis backend: %v. Falling back to memory-only mode.", err)
|
||||
initializeDefaultCaches(manager, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap with circuit breaker if enabled
|
||||
if redisConfig.EnableCircuitBreaker {
|
||||
cbConfig := resilience.DefaultCircuitBreakerConfig()
|
||||
cbConfig.MaxFailures = redisConfig.CircuitBreakerThreshold
|
||||
cbConfig.Timeout = time.Duration(redisConfig.CircuitBreakerTimeout) * time.Second
|
||||
cbConfig.OnStateChange = func(from, to resilience.State) {
|
||||
logger.Infof("Circuit breaker state changed from %s to %s", from, to)
|
||||
}
|
||||
|
||||
redisBackend = resilience.NewCircuitBreakerBackend(redisBackend, cbConfig)
|
||||
logger.Info("Redis backend wrapped with circuit breaker")
|
||||
}
|
||||
|
||||
// Wrap with health checker if enabled
|
||||
if redisConfig.EnableHealthCheck {
|
||||
hcConfig := &resilience.HealthCheckConfig{
|
||||
CheckInterval: time.Duration(redisConfig.HealthCheckInterval) * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
HealthyThreshold: 2,
|
||||
UnhealthyThreshold: 3,
|
||||
OnStatusChange: func(from, to resilience.HealthStatus) {
|
||||
logger.Infof("Redis backend health status changed from %s to %s", from, to)
|
||||
},
|
||||
}
|
||||
|
||||
redisBackend = resilience.NewHealthCheckBackend(redisBackend, hcConfig)
|
||||
logger.Info("Redis backend wrapped with health checker")
|
||||
}
|
||||
|
||||
// Decide which backend to use based on cache mode
|
||||
var createBackend func(cacheType CacheType) backends.CacheBackend
|
||||
|
||||
switch redisConfig.CacheMode {
|
||||
case "redis":
|
||||
// Redis-only mode
|
||||
createBackend = func(cacheType CacheType) backends.CacheBackend {
|
||||
return redisBackend
|
||||
}
|
||||
logger.Info("Using Redis-only cache backend")
|
||||
|
||||
case "hybrid":
|
||||
// Hybrid mode is not currently supported due to interface incompatibilities
|
||||
// Fall back to Redis-only mode
|
||||
logger.Info("Hybrid mode not currently supported, using Redis-only mode")
|
||||
createBackend = func(cacheType CacheType) backends.CacheBackend {
|
||||
return redisBackend
|
||||
}
|
||||
|
||||
default:
|
||||
// Memory-only mode (fallback)
|
||||
logger.Infof("Invalid cache mode: %s. Using memory-only mode.", redisConfig.CacheMode)
|
||||
initializeDefaultCaches(manager, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize token cache with backend
|
||||
manager.tokenCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000, // CRITICAL FIX: Reduced from 5000 to 1000 items
|
||||
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 5 * 1024 * 1024,
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
},
|
||||
createBackend(CacheTypeToken),
|
||||
)
|
||||
|
||||
// Initialize blacklist cache
|
||||
universalCacheManager.blacklistCache = NewUniversalCache(UniversalCacheConfig{
|
||||
// Initialize blacklist cache (CRITICAL - must be consistent across replicas)
|
||||
manager.blacklistCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 24 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
},
|
||||
createBackend("blacklist"),
|
||||
)
|
||||
|
||||
// Initialize metadata cache with grace periods
|
||||
universalCacheManager.metadataCache = NewUniversalCache(UniversalCacheConfig{
|
||||
// Initialize metadata cache
|
||||
manager.metadataCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeMetadata,
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
@@ -69,43 +264,50 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager {
|
||||
},
|
||||
},
|
||||
Logger: logger,
|
||||
})
|
||||
},
|
||||
createBackend(CacheTypeMetadata),
|
||||
)
|
||||
|
||||
// Initialize JWK cache
|
||||
universalCacheManager.jwkCache = NewUniversalCache(UniversalCacheConfig{
|
||||
// Initialize JWK cache
|
||||
manager.jwkCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeJWK,
|
||||
MaxSize: 200,
|
||||
DefaultTTL: 1 * time.Hour,
|
||||
Logger: logger,
|
||||
})
|
||||
},
|
||||
createBackend(CacheTypeJWK),
|
||||
)
|
||||
|
||||
// Initialize session cache - CRITICAL FIX: Reduced from 10000 to 2000
|
||||
universalCacheManager.sessionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 2000, // CRITICAL FIX: Reduced from 10000 to 2000 items
|
||||
MaxMemoryBytes: 5 * 1024 * 1024, // CRITICAL FIX: Added 5MB memory limit
|
||||
DefaultTTL: 30 * time.Minute,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662)
|
||||
universalCacheManager.introspectionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for introspection results
|
||||
MaxSize: 1000, // Cache up to 1000 introspection results
|
||||
DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently)
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize token type cache for performance optimization
|
||||
universalCacheManager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for token type detection
|
||||
MaxSize: 2000, // Cache up to 2000 token type detections
|
||||
DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection
|
||||
Logger: logger,
|
||||
})
|
||||
// Session cache stays memory-only (high volume, local state)
|
||||
manager.sessionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeSession,
|
||||
MaxSize: 2000,
|
||||
MaxMemoryBytes: 5 * 1024 * 1024,
|
||||
DefaultTTL: 30 * time.Minute,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
return universalCacheManager
|
||||
// Introspection cache uses backend for sharing results
|
||||
manager.introspectionCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
Logger: logger,
|
||||
},
|
||||
createBackend(CacheTypeToken),
|
||||
)
|
||||
|
||||
// Token type cache stays memory-only (local optimization)
|
||||
manager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 2000,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
// GetTokenCache returns the token cache
|
||||
@@ -174,6 +376,35 @@ func (m *UniversalCacheManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeCacheManagerFromConfig initializes the cache manager with configuration
|
||||
// This should be called early in the application startup with the loaded configuration
|
||||
func InitializeCacheManagerFromConfig(config *Config) *UniversalCacheManager {
|
||||
logger := NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides only when
|
||||
// the config field is not explicitly set through Traefik
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
// Log cache backend selection
|
||||
if config.Redis != nil && config.Redis.Enabled {
|
||||
logger.Infof("Initializing cache backend with Redis: mode=%s, address=%s",
|
||||
config.Redis.CacheMode, config.Redis.Address)
|
||||
} else {
|
||||
logger.Info("Initializing cache backend with memory-only mode")
|
||||
}
|
||||
|
||||
return GetUniversalCacheManagerWithConfig(logger, config.Redis)
|
||||
}
|
||||
|
||||
// ResetUniversalCacheManagerForTesting resets the singleton for testing purposes only
|
||||
// This should only be called in test code to ensure proper cleanup between tests
|
||||
func ResetUniversalCacheManagerForTesting() {
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
/integration/redis_src/
|
||||
/integration/dump.rdb
|
||||
*.swp
|
||||
/integration/nodes.conf
|
||||
.idea/
|
||||
miniredis.iml
|
||||
+328
@@ -0,0 +1,328 @@
|
||||
## Changelog
|
||||
|
||||
|
||||
## v2.35.0
|
||||
|
||||
- add Lua redis.setresp({2,3})
|
||||
- embed gopher-json package
|
||||
- fix XAUTOCLAIM (thanks @kgunning)
|
||||
- fix writeXpending (thanks @gnpaone)
|
||||
- fix BLMOVE TTL special case
|
||||
- constants for key types @alyssaruth
|
||||
|
||||
|
||||
### v2.34.0
|
||||
|
||||
- fix ZINTERSTORE where target is one of the source sets
|
||||
- added support for ZRank and ZRevRank with score (thanks Jeff Howell)
|
||||
- fix MEMORY subcommand casing (thanks @joshaber)
|
||||
- use streamCmp in Xtrim (thanks @daniel-cohere)
|
||||
|
||||
|
||||
### v2.33.0
|
||||
|
||||
- minimum Go version is now 1.17
|
||||
- fix integer overflow (thanks @wszaranski)
|
||||
- test against the last BSD redis (7.2.4)
|
||||
- ignore 'redis.set_repl()' call (thanks @TingluoHuang)
|
||||
- various build fixes (thanks @wszaranski)
|
||||
- add StartAddrTLS function (thanks @agriffaut)
|
||||
- support for the NOMKSTREAM option for XADD (thanks @Jahaja)
|
||||
- return empty array for SRANDMEMBER on nonexistent key (thanks @WKBae)
|
||||
|
||||
|
||||
### v2.32.1
|
||||
|
||||
- support for SINTERCARD (thanks @s-barr-fetch)
|
||||
- support for EXPIRETIME and PEXPIRETIME (thanks @wszaranski)
|
||||
- fix GEO* units to be case insensitive
|
||||
|
||||
|
||||
### v2.31.1
|
||||
|
||||
- support COUNT in SCAN and ZSCAN (thanks @BarakSilverfort)
|
||||
- support for OBJECT IDLETIME (thanks @nerd2)
|
||||
- support for HRANDFIELD (thanks @sejin-P)
|
||||
|
||||
|
||||
### v2.31.0
|
||||
|
||||
- support for MEMORY USAGE (thanks @davidroman0O)
|
||||
- test against Redis 7.2.0
|
||||
- support for CLIENT SETNAME/GETNAME (thanks @mr-karan)
|
||||
- fix very small numbers (thanks @zsh1995)
|
||||
- use the same float-to-string logic real Redis uses
|
||||
|
||||
|
||||
### v2.30.5
|
||||
|
||||
- support SMISMEMBER (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.30.4
|
||||
|
||||
- fix ZADD LT/LG (thanks @sejin-P)
|
||||
- fix COPY (thanks @jerargus)
|
||||
- quicker SPOP
|
||||
|
||||
|
||||
### v2.30.3
|
||||
|
||||
- fix lua error_reply (thanks @pkierski)
|
||||
- fix use of blocking functions in lua
|
||||
- support for ZMSCORE (thanks @lsgndln)
|
||||
- lua cache (thanks @tonyhb)
|
||||
|
||||
|
||||
### v2.30.2
|
||||
|
||||
- support MINID in XADD (thanks @nathan-cormier)
|
||||
- support BLMOVE (thanks @sevein)
|
||||
- fix COMMAND (thanks @pje)
|
||||
- fix 'XREAD ... $' on a non-existing stream
|
||||
|
||||
|
||||
### v2.30.1
|
||||
|
||||
- support SET NX GET special case
|
||||
|
||||
|
||||
### v2.30.0
|
||||
|
||||
- implement redis 7.0.x (from 6.X). Main changes:
|
||||
- test against 7.0.7
|
||||
- update error messages
|
||||
- support nx|xx|gt|lt options in [P]EXPIRE[AT]
|
||||
- update how deleted items are processed in pending queues in streams
|
||||
|
||||
|
||||
### v2.23.1
|
||||
|
||||
- resolve $ to latest ID in XREAD (thanks @josh-hook)
|
||||
- handle disconnect in blocking functions (thanks @jgirtakovskis)
|
||||
- fix type conversion bug in redisToLua (thanks Sandy Harvie)
|
||||
- BRPOP{LPUSH} timeout can be float since 6.0
|
||||
|
||||
|
||||
### v2.23.0
|
||||
|
||||
- basic INFO support (thanks @kirill-a-belov)
|
||||
- support COUNT in SSCAN (thanks @Abdi-dd)
|
||||
- test and support Go 1.19
|
||||
- support LPOS (thanks @ianstarz)
|
||||
- support XPENDING, XGROUP {CREATECONSUMER,DESTROY,DELCONSUMER}, XINFO {CONSUMERS,GROUPS}, XCLAIM (thanks @sandyharvie)
|
||||
|
||||
|
||||
### v2.22.0
|
||||
|
||||
- set miniredis.DumpMaxLineLen to get more Dump() info (thanks @afjoseph)
|
||||
- fix invalid resposne of COMMAND (thanks @zsh1995)
|
||||
- fix possibility to generate duplicate IDs in XADD (thanks @readams)
|
||||
- adds support for XAUTOCLAIM min-idle parameter (thanks @readams)
|
||||
|
||||
|
||||
### v2.21.0
|
||||
|
||||
- support for GETEX (thanks @dntj)
|
||||
- support for GT and LT in ZADD (thanks @lsgndln)
|
||||
- support for XAUTOCLAIM (thanks @randall-fulton)
|
||||
|
||||
|
||||
### v2.20.0
|
||||
|
||||
- back to support Go >= 1.14 (thanks @ajatprabha and @marcind)
|
||||
|
||||
|
||||
### v2.19.0
|
||||
|
||||
- support for TYPE in SCAN (thanks @0xDiddi)
|
||||
- update BITPOS (thanks @dirkm)
|
||||
- fix a lua redis.call() return value (thanks @mpetronic)
|
||||
- update ZRANGE (thanks @valdemarpereira)
|
||||
|
||||
|
||||
### v2.18.0
|
||||
|
||||
- support for ZUNION (thanks @propan)
|
||||
- support for COPY (thanks @matiasinsaurralde and @rockitbaby)
|
||||
- support for LMOVE (thanks @btwear)
|
||||
|
||||
|
||||
### v2.17.0
|
||||
|
||||
- added miniredis.RunT(t)
|
||||
|
||||
|
||||
### v2.16.1
|
||||
|
||||
- fix ZINTERSTORE with sets (thanks @lingjl2010 and @okhowang)
|
||||
- fix exclusive ranges in XRANGE (thanks @joseotoro)
|
||||
|
||||
|
||||
### v2.16.0
|
||||
|
||||
- simplify some code (thanks @zonque)
|
||||
- support for EXAT/PXAT in SET
|
||||
- support for XTRIM (thanks @joseotoro)
|
||||
- support for ZRANDMEMBER
|
||||
- support for redis.log() in lua (thanks @dirkm)
|
||||
|
||||
|
||||
### v2.15.2
|
||||
|
||||
- Fix race condition in blocking code (thanks @zonque and @robx)
|
||||
- XREAD accepts '$' as ID (thanks @bradengroom)
|
||||
|
||||
|
||||
### v2.15.1
|
||||
|
||||
- EVAL should cache the script (thanks @guoshimin)
|
||||
|
||||
|
||||
### v2.15.0
|
||||
|
||||
- target redis 6.2 and added new args to various commands
|
||||
- support for all hyperlog commands (thanks @ilbaktin)
|
||||
- support for GETDEL (thanks @wszaranski)
|
||||
|
||||
|
||||
### v2.14.5
|
||||
|
||||
- added XPENDING
|
||||
- support for BLOCK option in XREAD and XREADGROUP
|
||||
|
||||
|
||||
### v2.14.4
|
||||
|
||||
- fix BITPOS error (thanks @xiaoyuzdy)
|
||||
- small fixes for XREAD, XACK, and XDEL. Mostly error cases.
|
||||
- fix empty EXEC return type (thanks @ashanbrown)
|
||||
- fix XDEL (thanks @svakili and @yvesf)
|
||||
- fix FLUSHALL for streams (thanks @svakili)
|
||||
|
||||
|
||||
### v2.14.3
|
||||
|
||||
- fix problem where Lua code didn't set the selected DB
|
||||
- update to redis 6.0.10 (thanks @lazappa)
|
||||
|
||||
|
||||
### v2.14.2
|
||||
|
||||
- update LUA dependency
|
||||
- deal with (p)unsubscribe when there are no channels
|
||||
|
||||
|
||||
### v2.14.1
|
||||
|
||||
- mod tidy
|
||||
|
||||
|
||||
### v2.14.0
|
||||
|
||||
- support for HELLO and the RESP3 protocol
|
||||
- KEEPTTL in SET (thanks @johnpena)
|
||||
|
||||
|
||||
### v2.13.3
|
||||
|
||||
- support Go 1.14 and 1.15
|
||||
- update the `Check...()` methods
|
||||
- support for XREAD (thanks @pieterlexis)
|
||||
|
||||
|
||||
### v2.13.2
|
||||
|
||||
- Use SAN instead of CN in self signed cert for testing (thanks @johejo)
|
||||
- Travis CI now tests against the most recent two versions of Go (thanks @johejo)
|
||||
- changed unit and integration tests to compare raw payloads, not parsed payloads
|
||||
- remove "redigo" dependency
|
||||
|
||||
|
||||
### v2.13.1
|
||||
|
||||
- added HSTRLEN
|
||||
- minimal support for ACL users in AUTH
|
||||
|
||||
|
||||
### v2.13.0
|
||||
|
||||
- added RunTLS(...)
|
||||
- added SetError(...)
|
||||
|
||||
|
||||
### v2.12.0
|
||||
|
||||
- redis 6
|
||||
- Lua json update (thanks @gsmith85)
|
||||
- CLUSTER commands (thanks @kratisto)
|
||||
- fix TOUCH
|
||||
- fix a shutdown race condition
|
||||
|
||||
|
||||
### v2.11.4
|
||||
|
||||
- ZUNIONSTORE now supports standard set types (thanks @wshirey)
|
||||
|
||||
|
||||
### v2.11.3
|
||||
|
||||
- support for TOUCH (thanks @cleroux)
|
||||
- support for cluster and stream commands (thanks @kak-tus)
|
||||
|
||||
|
||||
### v2.11.2
|
||||
|
||||
- make sure Lua code is executed concurrently
|
||||
- add command GEORADIUSBYMEMBER (thanks @kyeett)
|
||||
|
||||
|
||||
### v2.11.1
|
||||
|
||||
- globals protection for Lua code (thanks @vk-outreach)
|
||||
- HSET update (thanks @carlgreen)
|
||||
- fix BLPOP block on shutdown (thanks @Asalle)
|
||||
|
||||
|
||||
### v2.11.0
|
||||
|
||||
- added XRANGE/XREVRANGE, XADD, and XLEN (thanks @skateinmars)
|
||||
- added GEODIST
|
||||
- improved precision for geohashes, closer to what real redis does
|
||||
- use 128bit floats internally for INCRBYFLOAT and related (thanks @timnd)
|
||||
|
||||
|
||||
### v2.10.1
|
||||
|
||||
- added m.Server()
|
||||
|
||||
|
||||
### v2.10.0
|
||||
|
||||
- added UNLINK
|
||||
- fix DEL zero-argument case
|
||||
- cleanup some direct access commands
|
||||
- added GEOADD, GEOPOS, GEORADIUS, and GEORADIUS_RO
|
||||
|
||||
|
||||
### v2.9.1
|
||||
|
||||
- fix issue with ZRANGEBYLEX
|
||||
- fix issue with BRPOPLPUSH and direct access
|
||||
|
||||
|
||||
### v2.9.0
|
||||
|
||||
- proper versioned import of github.com/gomodule/redigo (thanks @yfei1)
|
||||
- fix messages generated by PSUBSCRIBE
|
||||
- optional internal seed (thanks @zikaeroh)
|
||||
|
||||
|
||||
### v2.8.0
|
||||
|
||||
Proper `v2` in go.mod.
|
||||
|
||||
|
||||
### older
|
||||
|
||||
See https://github.com/alicebob/miniredis/releases for the full changelog
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Harmen
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+33
@@ -0,0 +1,33 @@
|
||||
.PHONY: test
|
||||
test: ### Run unit tests
|
||||
go test ./...
|
||||
|
||||
.PHONY: testrace
|
||||
testrace: ### Run unit tests with race detector
|
||||
go test -race ./...
|
||||
|
||||
.PHONY: int
|
||||
int: ### Run integration tests (doesn't download redis server)
|
||||
${MAKE} -C integration int
|
||||
|
||||
.PHONY: ci
|
||||
ci: ### Run full tests suite (including download and compilation of proper redis server)
|
||||
${MAKE} test
|
||||
${MAKE} -C integration redis_src/redis-server int
|
||||
${MAKE} testrace
|
||||
|
||||
.PHONY: clean
|
||||
clean: ### Clean integration test files and remove compiled redis from integration/redis_src
|
||||
${MAKE} -C integration clean
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
ifeq ($(UNAME), Linux)
|
||||
@grep -P '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
|
||||
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
else
|
||||
@# this is not tested, but prepared in advance for you, Mac drivers
|
||||
@awk -F ':.*###' '$$0 ~ FS {printf "%15s%s\n", $$1 ":", $$2}' \
|
||||
$(MAKEFILE_LIST) | grep -v '@awk' | sort
|
||||
endif
|
||||
|
||||
+342
@@ -0,0 +1,342 @@
|
||||
# Miniredis
|
||||
|
||||
Pure Go Redis test server, used in Go unittests.
|
||||
|
||||
|
||||
##
|
||||
|
||||
Sometimes you want to test code which uses Redis, without making it a full-blown
|
||||
integration test.
|
||||
Miniredis implements (parts of) the Redis server, to be used in unittests. It
|
||||
enables a simple, cheap, in-memory, Redis replacement, with a real TCP interface. Think of it as the Redis version of `net/http/httptest`.
|
||||
|
||||
It saves you from using mock code, and since the redis server lives in the
|
||||
test process you can query for values directly, without going through the server
|
||||
stack.
|
||||
|
||||
There are no dependencies on external binaries, so you can easily integrate it in automated build processes.
|
||||
|
||||
Be sure to import v2:
|
||||
```
|
||||
import "github.com/alicebob/miniredis/v2"
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
Implemented commands:
|
||||
|
||||
- Connection (complete)
|
||||
- AUTH -- see RequireAuth()
|
||||
- ECHO
|
||||
- HELLO -- see RequireUserAuth()
|
||||
- PING
|
||||
- SELECT
|
||||
- SWAPDB
|
||||
- QUIT
|
||||
- Key
|
||||
- COPY
|
||||
- DEL
|
||||
- EXISTS
|
||||
- EXPIRE
|
||||
- EXPIREAT
|
||||
- EXPIRETIME
|
||||
- KEYS
|
||||
- MOVE
|
||||
- PERSIST
|
||||
- PEXPIRE
|
||||
- PEXPIREAT
|
||||
- PEXPIRETIME
|
||||
- PTTL
|
||||
- RANDOMKEY -- see m.Seed(...)
|
||||
- RENAME
|
||||
- RENAMENX
|
||||
- SCAN
|
||||
- TOUCH
|
||||
- TTL
|
||||
- TYPE
|
||||
- UNLINK
|
||||
- Transactions (complete)
|
||||
- DISCARD
|
||||
- EXEC
|
||||
- MULTI
|
||||
- UNWATCH
|
||||
- WATCH
|
||||
- Server
|
||||
- DBSIZE
|
||||
- FLUSHALL
|
||||
- FLUSHDB
|
||||
- TIME -- returns time.Now() or value set by SetTime()
|
||||
- COMMAND -- partly
|
||||
- INFO -- partly, returns only "clients" section with one field "connected_clients"
|
||||
- String keys (complete)
|
||||
- APPEND
|
||||
- BITCOUNT
|
||||
- BITOP
|
||||
- BITPOS
|
||||
- DECR
|
||||
- DECRBY
|
||||
- GET
|
||||
- GETBIT
|
||||
- GETRANGE
|
||||
- GETSET
|
||||
- GETDEL
|
||||
- GETEX
|
||||
- INCR
|
||||
- INCRBY
|
||||
- INCRBYFLOAT
|
||||
- MGET
|
||||
- MSET
|
||||
- MSETNX
|
||||
- PSETEX
|
||||
- SET
|
||||
- SETBIT
|
||||
- SETEX
|
||||
- SETNX
|
||||
- SETRANGE
|
||||
- STRLEN
|
||||
- Hash keys (complete)
|
||||
- HDEL
|
||||
- HEXISTS
|
||||
- HGET
|
||||
- HGETALL
|
||||
- HINCRBY
|
||||
- HINCRBYFLOAT
|
||||
- HKEYS
|
||||
- HLEN
|
||||
- HMGET
|
||||
- HMSET
|
||||
- HRANDFIELD
|
||||
- HSET
|
||||
- HSETNX
|
||||
- HSTRLEN
|
||||
- HVALS
|
||||
- HSCAN
|
||||
- List keys (complete)
|
||||
- BLPOP
|
||||
- BRPOP
|
||||
- BRPOPLPUSH
|
||||
- LINDEX
|
||||
- LINSERT
|
||||
- LLEN
|
||||
- LPOP
|
||||
- LPUSH
|
||||
- LPUSHX
|
||||
- LRANGE
|
||||
- LREM
|
||||
- LSET
|
||||
- LTRIM
|
||||
- RPOP
|
||||
- RPOPLPUSH
|
||||
- RPUSH
|
||||
- RPUSHX
|
||||
- LMOVE
|
||||
- BLMOVE
|
||||
- Pub/Sub (complete)
|
||||
- PSUBSCRIBE
|
||||
- PUBLISH
|
||||
- PUBSUB
|
||||
- PUNSUBSCRIBE
|
||||
- SUBSCRIBE
|
||||
- UNSUBSCRIBE
|
||||
- Set keys (complete)
|
||||
- SADD
|
||||
- SCARD
|
||||
- SDIFF
|
||||
- SDIFFSTORE
|
||||
- SINTER
|
||||
- SINTERSTORE
|
||||
- SINTERCARD
|
||||
- SISMEMBER
|
||||
- SMEMBERS
|
||||
- SMISMEMBER
|
||||
- SMOVE
|
||||
- SPOP -- see m.Seed(...)
|
||||
- SRANDMEMBER -- see m.Seed(...)
|
||||
- SREM
|
||||
- SSCAN
|
||||
- SUNION
|
||||
- SUNIONSTORE
|
||||
- Sorted Set keys (complete)
|
||||
- ZADD
|
||||
- ZCARD
|
||||
- ZCOUNT
|
||||
- ZINCRBY
|
||||
- ZINTER
|
||||
- ZINTERSTORE
|
||||
- ZLEXCOUNT
|
||||
- ZPOPMIN
|
||||
- ZPOPMAX
|
||||
- ZRANDMEMBER
|
||||
- ZRANGE
|
||||
- ZRANGEBYLEX
|
||||
- ZRANGEBYSCORE
|
||||
- ZRANK
|
||||
- ZREM
|
||||
- ZREMRANGEBYLEX
|
||||
- ZREMRANGEBYRANK
|
||||
- ZREMRANGEBYSCORE
|
||||
- ZREVRANGE
|
||||
- ZREVRANGEBYLEX
|
||||
- ZREVRANGEBYSCORE
|
||||
- ZREVRANK
|
||||
- ZSCORE
|
||||
- ZUNION
|
||||
- ZUNIONSTORE
|
||||
- ZSCAN
|
||||
- Stream keys
|
||||
- XACK
|
||||
- XADD
|
||||
- XAUTOCLAIM
|
||||
- XCLAIM
|
||||
- XDEL
|
||||
- XGROUP CREATE
|
||||
- XGROUP CREATECONSUMER
|
||||
- XGROUP DESTROY
|
||||
- XGROUP DELCONSUMER
|
||||
- XINFO STREAM -- partly
|
||||
- XINFO GROUPS
|
||||
- XINFO CONSUMERS -- partly
|
||||
- XLEN
|
||||
- XRANGE
|
||||
- XREAD
|
||||
- XREADGROUP
|
||||
- XREVRANGE
|
||||
- XPENDING
|
||||
- XTRIM
|
||||
- Scripting
|
||||
- EVAL
|
||||
- EVALSHA
|
||||
- SCRIPT LOAD
|
||||
- SCRIPT EXISTS
|
||||
- SCRIPT FLUSH
|
||||
- GEO
|
||||
- GEOADD
|
||||
- GEODIST
|
||||
- ~~GEOHASH~~
|
||||
- GEOPOS
|
||||
- GEORADIUS
|
||||
- GEORADIUS_RO
|
||||
- GEORADIUSBYMEMBER
|
||||
- GEORADIUSBYMEMBER_RO
|
||||
- Cluster
|
||||
- CLUSTER SLOTS
|
||||
- CLUSTER KEYSLOT
|
||||
- CLUSTER NODES
|
||||
- HyperLogLog (complete)
|
||||
- PFADD
|
||||
- PFCOUNT
|
||||
- PFMERGE
|
||||
|
||||
|
||||
## TTLs, key expiration, and time
|
||||
|
||||
Since miniredis is intended to be used in unittests TTLs don't decrease
|
||||
automatically. You can use `TTL()` to get the TTL (as a time.Duration) of a
|
||||
key. It will return 0 when no TTL is set.
|
||||
|
||||
`m.FastForward(d)` can be used to decrement all TTLs. All TTLs which become <=
|
||||
0 will be removed.
|
||||
|
||||
EXPIREAT and PEXPIREAT values will be
|
||||
converted to a duration. For that you can either set m.SetTime(t) to use that
|
||||
time as the base for the (P)EXPIREAT conversion, or don't call SetTime(), in
|
||||
which case time.Now() will be used.
|
||||
|
||||
SetTime() also sets the value returned by TIME, which defaults to time.Now().
|
||||
It is not updated by FastForward, only by SetTime.
|
||||
|
||||
## Randomness and Seed()
|
||||
|
||||
Miniredis will use `math/rand`'s global RNG for randomness unless a seed is
|
||||
provided by calling `m.Seed(...)`. If a seed is provided, then miniredis will
|
||||
use its own RNG based on that seed.
|
||||
|
||||
Commands which use randomness are: RANDOMKEY, SPOP, and SRANDMEMBER.
|
||||
|
||||
## Example
|
||||
|
||||
``` Go
|
||||
|
||||
import (
|
||||
...
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
...
|
||||
)
|
||||
|
||||
func TestSomething(t *testing.T) {
|
||||
s := miniredis.RunT(t)
|
||||
|
||||
// Optionally set some keys your code expects:
|
||||
s.Set("foo", "bar")
|
||||
s.HSet("some", "other", "key")
|
||||
|
||||
// Run your code and see if it behaves.
|
||||
// An example using the redigo library from "github.com/gomodule/redigo/redis":
|
||||
c, err := redis.Dial("tcp", s.Addr())
|
||||
_, err = c.Do("SET", "foo", "bar")
|
||||
|
||||
// Optionally check values in redis...
|
||||
if got, err := s.Get("foo"); err != nil || got != "bar" {
|
||||
t.Error("'foo' has the wrong value")
|
||||
}
|
||||
// ... or use a helper for that:
|
||||
s.CheckGet(t, "foo", "bar")
|
||||
|
||||
// TTL and expiration:
|
||||
s.Set("foo", "bar")
|
||||
s.SetTTL("foo", 10*time.Second)
|
||||
s.FastForward(11 * time.Second)
|
||||
if s.Exists("foo") {
|
||||
t.Fatal("'foo' should not have existed anymore")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Not supported
|
||||
|
||||
Commands which will probably not be implemented:
|
||||
|
||||
- CLUSTER (all)
|
||||
- ~~CLUSTER *~~
|
||||
- ~~READONLY~~
|
||||
- ~~READWRITE~~
|
||||
- Key
|
||||
- ~~DUMP~~
|
||||
- ~~MIGRATE~~
|
||||
- ~~OBJECT~~
|
||||
- ~~RESTORE~~
|
||||
- ~~WAIT~~
|
||||
- Scripting
|
||||
- ~~FCALL / FCALL_RO *~~
|
||||
- ~~FUNCTION *~~
|
||||
- ~~SCRIPT DEBUG~~
|
||||
- ~~SCRIPT KILL~~
|
||||
- Server
|
||||
- ~~BGSAVE~~
|
||||
- ~~BGWRITEAOF~~
|
||||
- ~~CLIENT *~~
|
||||
- ~~CONFIG *~~
|
||||
- ~~DEBUG *~~
|
||||
- ~~LASTSAVE~~
|
||||
- ~~MONITOR~~
|
||||
- ~~ROLE~~
|
||||
- ~~SAVE~~
|
||||
- ~~SHUTDOWN~~
|
||||
- ~~SLAVEOF~~
|
||||
- ~~SLOWLOG~~
|
||||
- ~~SYNC~~
|
||||
|
||||
|
||||
## &c.
|
||||
|
||||
Integration tests are run against Redis 7.2.4. The [./integration](./integration/) subdir
|
||||
compares miniredis against a real redis instance.
|
||||
|
||||
The Redis 6 RESP3 protocol is supported. If there are problems, please open
|
||||
an issue.
|
||||
|
||||
If you want to test Redis Sentinel have a look at [minisentinel](https://github.com/Bose/minisentinel).
|
||||
|
||||
A changelog is kept at [CHANGELOG.md](https://github.com/alicebob/miniredis/blob/master/CHANGELOG.md).
|
||||
|
||||
[](https://pkg.go.dev/github.com/alicebob/miniredis/v2)
|
||||
+63
@@ -0,0 +1,63 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// T is implemented by Testing.T
|
||||
type T interface {
|
||||
Helper()
|
||||
Errorf(string, ...interface{})
|
||||
}
|
||||
|
||||
// CheckGet does not call Errorf() iff there is a string key with the
|
||||
// expected value. Normal use case is `m.CheckGet(t, "username", "theking")`.
|
||||
func (m *Miniredis) CheckGet(t T, key, expected string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Get(key)
|
||||
if err != nil {
|
||||
t.Errorf("GET error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if found != expected {
|
||||
t.Errorf("GET error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckList does not call Errorf() iff there is a list key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckGet(t, "favorite_colors", "red", "green", "infrared")`.
|
||||
func (m *Miniredis) CheckList(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.List(key)
|
||||
if err != nil {
|
||||
t.Errorf("List error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("List error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CheckSet does not call Errorf() iff there is a set key with the
|
||||
// expected values.
|
||||
// Normal use case is `m.CheckSet(t, "visited", "Rome", "Stockholm", "Dublin")`.
|
||||
func (m *Miniredis) CheckSet(t T, key string, expected ...string) {
|
||||
t.Helper()
|
||||
|
||||
found, err := m.Members(key)
|
||||
if err != nil {
|
||||
t.Errorf("Set error, key %#v: %v", key, err)
|
||||
return
|
||||
}
|
||||
sort.Strings(expected)
|
||||
if !reflect.DeepEqual(expected, found) {
|
||||
t.Errorf("Set error, key %#v: Expected %#v, got %#v", key, expected, found)
|
||||
return
|
||||
}
|
||||
}
|
||||
+68
@@ -0,0 +1,68 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsClient handles client operations.
|
||||
func commandsClient(m *Miniredis) {
|
||||
m.srv.Register("CLIENT", m.cmdClient)
|
||||
}
|
||||
|
||||
// CLIENT
|
||||
func (m *Miniredis) cmdClient(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR wrong number of arguments for 'client' command")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
switch cmd := strings.ToUpper(args[0]); cmd {
|
||||
case "SETNAME":
|
||||
m.cmdClientSetName(c, args[1:])
|
||||
case "GETNAME":
|
||||
m.cmdClientGetName(c, args[1:])
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("ERR unknown subcommand '%s'. Try CLIENT HELP.", cmd))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// CLIENT SETNAME
|
||||
func (m *Miniredis) cmdClientSetName(c *server.Peer, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR wrong number of arguments for 'client setname' command")
|
||||
return
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
if strings.ContainsAny(name, " \n") {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR Client names cannot contain spaces, newlines or special characters.")
|
||||
return
|
||||
|
||||
}
|
||||
c.ClientName = name
|
||||
c.WriteOK()
|
||||
}
|
||||
|
||||
// CLIENT GETNAME
|
||||
func (m *Miniredis) cmdClientGetName(c *server.Peer, args []string) {
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR wrong number of arguments for 'client getname' command")
|
||||
return
|
||||
}
|
||||
|
||||
if c.ClientName == "" {
|
||||
c.WriteNull()
|
||||
} else {
|
||||
c.WriteBulk(c.ClientName)
|
||||
}
|
||||
}
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
// Commands from https://redis.io/commands#cluster
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsCluster handles some cluster operations.
|
||||
func commandsCluster(m *Miniredis) {
|
||||
m.srv.Register("CLUSTER", m.cmdCluster)
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdCluster(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "SLOTS":
|
||||
m.cmdClusterSlots(c, cmd, args)
|
||||
case "KEYSLOT":
|
||||
m.cmdClusterKeySlot(c, cmd, args)
|
||||
case "NODES":
|
||||
m.cmdClusterNodes(c, cmd, args)
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("ERR 'CLUSTER %s' not supported", strings.Join(args, " ")))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CLUSTER SLOTS
|
||||
func (m *Miniredis) cmdClusterSlots(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteLen(1)
|
||||
c.WriteLen(3)
|
||||
c.WriteInt(0)
|
||||
c.WriteInt(16383)
|
||||
c.WriteLen(3)
|
||||
c.WriteBulk(m.srv.Addr().IP.String())
|
||||
c.WriteInt(m.srv.Addr().Port)
|
||||
c.WriteBulk("09dbe9720cda62f7865eabc5fd8857c5d2678366")
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER KEYSLOT
|
||||
func (m *Miniredis) cmdClusterKeySlot(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteInt(163)
|
||||
})
|
||||
}
|
||||
|
||||
// CLUSTER NODES
|
||||
func (m *Miniredis) cmdClusterNodes(c *server.Peer, cmd string, args []string) {
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk("e7d1eecce10fd6bb5eb35b9f99a514335d9ba9ca 127.0.0.1:7000@7000 myself,master - 0 0 1 connected 0-16383")
|
||||
})
|
||||
}
|
||||
+14
File diff suppressed because one or more lines are too long
+285
@@ -0,0 +1,285 @@
|
||||
// Commands from https://redis.io/commands#connection
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
func commandsConnection(m *Miniredis) {
|
||||
m.srv.Register("AUTH", m.cmdAuth)
|
||||
m.srv.Register("ECHO", m.cmdEcho)
|
||||
m.srv.Register("HELLO", m.cmdHello)
|
||||
m.srv.Register("PING", m.cmdPing)
|
||||
m.srv.Register("QUIT", m.cmdQuit)
|
||||
m.srv.Register("SELECT", m.cmdSelect)
|
||||
m.srv.Register("SWAPDB", m.cmdSwapdb)
|
||||
}
|
||||
|
||||
// PING
|
||||
func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
payload := ""
|
||||
if len(args) > 0 {
|
||||
payload = args[0]
|
||||
}
|
||||
|
||||
// PING is allowed in subscribed state
|
||||
if sub := getCtx(c).subscriber; sub != nil {
|
||||
c.Block(func(c *server.Writer) {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("pong")
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if payload == "" {
|
||||
c.WriteInline("PONG")
|
||||
return
|
||||
}
|
||||
c.WriteBulk(payload)
|
||||
})
|
||||
}
|
||||
|
||||
// AUTH
|
||||
func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
var opts = struct {
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
username: "default",
|
||||
password: args[0],
|
||||
}
|
||||
if len(args) == 2 {
|
||||
opts.username, opts.password = args[0], args[1]
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
c.WriteError("ERR AUTH <password> called without any password configured for the default user. Are you sure your configuration is correct?")
|
||||
return
|
||||
}
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
|
||||
ctx.authenticated = true
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HELLO
|
||||
func (m *Miniredis) cmdHello(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
version int
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.version, "ERR Protocol version is not an integer or out of range"); !ok {
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
switch opts.version {
|
||||
case 2, 3:
|
||||
default:
|
||||
c.WriteError("NOPROTO unsupported protocol version")
|
||||
return
|
||||
}
|
||||
|
||||
var checkAuth bool
|
||||
for len(args) > 0 {
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "AUTH":
|
||||
if len(args) < 3 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
opts.username, opts.password, args = args[1], args[2], args[3:]
|
||||
checkAuth = true
|
||||
case "SETNAME":
|
||||
if len(args) < 2 {
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
_, args = args[1], args[2:]
|
||||
default:
|
||||
c.WriteError(fmt.Sprintf("ERR Syntax error in HELLO option '%s'", args[0]))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.passwords) == 0 && opts.username == "default" {
|
||||
// redis ignores legacy "AUTH" if it's not enabled.
|
||||
checkAuth = false
|
||||
}
|
||||
if checkAuth {
|
||||
setPW, ok := m.passwords[opts.username]
|
||||
if !ok {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
if setPW != opts.password {
|
||||
c.WriteError("WRONGPASS invalid username-password pair")
|
||||
return
|
||||
}
|
||||
getCtx(c).authenticated = true
|
||||
}
|
||||
|
||||
c.Resp3 = opts.version == 3
|
||||
|
||||
c.WriteMapLen(7)
|
||||
c.WriteBulk("server")
|
||||
c.WriteBulk("miniredis")
|
||||
c.WriteBulk("version")
|
||||
c.WriteBulk("6.0.5")
|
||||
c.WriteBulk("proto")
|
||||
c.WriteInt(opts.version)
|
||||
c.WriteBulk("id")
|
||||
c.WriteInt(42)
|
||||
c.WriteBulk("mode")
|
||||
c.WriteBulk("standalone")
|
||||
c.WriteBulk("role")
|
||||
c.WriteBulk("master")
|
||||
c.WriteBulk("modules")
|
||||
c.WriteLen(0)
|
||||
}
|
||||
|
||||
// ECHO
|
||||
func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteBulk(msg)
|
||||
})
|
||||
}
|
||||
|
||||
// SELECT
|
||||
func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id int
|
||||
}
|
||||
if ok := optInt(c, args[0], &opts.id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.selectedDB = opts.id
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// SWAPDB
|
||||
func (m *Miniredis) cmdSwapdb(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
id1 int
|
||||
id2 int
|
||||
}
|
||||
|
||||
if ok := optIntErr(c, args[0], &opts.id1, "ERR invalid first DB index"); !ok {
|
||||
return
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.id2, "ERR invalid second DB index"); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if opts.id1 < 0 || opts.id2 < 0 {
|
||||
c.WriteError(msgDBIndexOutOfRange)
|
||||
setDirty(c)
|
||||
return
|
||||
}
|
||||
|
||||
m.swapDB(opts.id1, opts.id2)
|
||||
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// QUIT
|
||||
func (m *Miniredis) cmdQuit(c *server.Peer, cmd string, args []string) {
|
||||
// QUIT isn't transactionfied and accepts any arguments.
|
||||
c.WriteOK()
|
||||
c.Close()
|
||||
}
|
||||
+813
@@ -0,0 +1,813 @@
|
||||
// Commands from https://redis.io/commands#generic
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
const (
|
||||
// expiretimeReplyNoExpiration is return value for EXPIRETIME and PEXPIRETIME if the key exists but has no associated expiration time
|
||||
expiretimeReplyNoExpiration = -1
|
||||
// expiretimeReplyMissingKey is return value for EXPIRETIME and PEXPIRETIME if the key does not exist
|
||||
expiretimeReplyMissingKey = -2
|
||||
)
|
||||
|
||||
func inSeconds(t time.Time) int {
|
||||
return int(t.Unix())
|
||||
}
|
||||
|
||||
func inMilliSeconds(t time.Time) int {
|
||||
return int(t.UnixMilli())
|
||||
}
|
||||
|
||||
// commandsGeneric handles EXPIRE, TTL, PERSIST, &c.
|
||||
func commandsGeneric(m *Miniredis) {
|
||||
m.srv.Register("COPY", m.cmdCopy)
|
||||
m.srv.Register("DEL", m.cmdDel)
|
||||
// DUMP
|
||||
m.srv.Register("EXISTS", m.cmdExists)
|
||||
m.srv.Register("EXPIRE", makeCmdExpire(m, false, time.Second))
|
||||
m.srv.Register("EXPIREAT", makeCmdExpire(m, true, time.Second))
|
||||
m.srv.Register("EXPIRETIME", m.makeCmdExpireTime(inSeconds))
|
||||
m.srv.Register("PEXPIRETIME", m.makeCmdExpireTime(inMilliSeconds))
|
||||
m.srv.Register("KEYS", m.cmdKeys)
|
||||
// MIGRATE
|
||||
m.srv.Register("MOVE", m.cmdMove)
|
||||
// OBJECT
|
||||
m.srv.Register("PERSIST", m.cmdPersist)
|
||||
m.srv.Register("PEXPIRE", makeCmdExpire(m, false, time.Millisecond))
|
||||
m.srv.Register("PEXPIREAT", makeCmdExpire(m, true, time.Millisecond))
|
||||
m.srv.Register("PTTL", m.cmdPTTL)
|
||||
m.srv.Register("RANDOMKEY", m.cmdRandomkey)
|
||||
m.srv.Register("RENAME", m.cmdRename)
|
||||
m.srv.Register("RENAMENX", m.cmdRenamenx)
|
||||
// RESTORE
|
||||
m.srv.Register("TOUCH", m.cmdTouch)
|
||||
m.srv.Register("TTL", m.cmdTTL)
|
||||
m.srv.Register("TYPE", m.cmdType)
|
||||
m.srv.Register("SCAN", m.cmdScan)
|
||||
// SORT
|
||||
m.srv.Register("UNLINK", m.cmdDel)
|
||||
}
|
||||
|
||||
type expireOpts struct {
|
||||
key string
|
||||
value int
|
||||
nx bool
|
||||
xx bool
|
||||
gt bool
|
||||
lt bool
|
||||
}
|
||||
|
||||
func expireParse(cmd string, args []string) (*expireOpts, error) {
|
||||
var opts expireOpts
|
||||
|
||||
opts.key = args[0]
|
||||
if err := optIntSimple(args[1], &opts.value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = args[2:]
|
||||
for len(args) > 0 {
|
||||
switch strings.ToLower(args[0]) {
|
||||
case "nx":
|
||||
opts.nx = true
|
||||
case "xx":
|
||||
opts.xx = true
|
||||
case "gt":
|
||||
opts.gt = true
|
||||
case "lt":
|
||||
opts.lt = true
|
||||
default:
|
||||
return nil, fmt.Errorf("ERR Unsupported option %s", args[0])
|
||||
}
|
||||
args = args[1:]
|
||||
}
|
||||
if opts.gt && opts.lt {
|
||||
return nil, errors.New("ERR GT and LT options at the same time are not compatible")
|
||||
}
|
||||
if opts.nx && (opts.xx || opts.gt || opts.lt) {
|
||||
return nil, errors.New("ERR NX and XX, GT or LT options at the same time are not compatible")
|
||||
}
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// generic expire command for EXPIRE, PEXPIRE, EXPIREAT, PEXPIREAT
|
||||
// d is the time unit. If unix is set it'll be seen as a unixtimestamp and
|
||||
// converted to a duration.
|
||||
func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, string, []string) {
|
||||
return func(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := expireParse(cmd, args)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
// Key must be present.
|
||||
if _, ok := db.keys[opts.key]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
oldTTL, ok := db.ttl[opts.key]
|
||||
|
||||
var newTTL time.Duration
|
||||
if unix {
|
||||
newTTL = m.at(opts.value, d)
|
||||
} else {
|
||||
newTTL = time.Duration(opts.value) * d
|
||||
}
|
||||
|
||||
// > NX -- Set expiry only when the key has no expiry
|
||||
if opts.nx && ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
// > XX -- Set expiry only when the key has an existing expiry
|
||||
if opts.xx && !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
// > GT -- Set expiry only when the new expiry is greater than current one
|
||||
// (no exp == infinity)
|
||||
if opts.gt && (!ok || newTTL <= oldTTL) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
// > LT -- Set expiry only when the new expiry is less than current one
|
||||
if opts.lt && ok && newTTL > oldTTL {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
db.ttl[opts.key] = newTTL
|
||||
db.incr(opts.key)
|
||||
db.checkTTL(opts.key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// makeCmdExpireTime creates server command function that returns the absolute Unix timestamp (since January 1, 1970)
|
||||
// at which the given key will expire, in unit selected by time result strategy (e.g. seconds, milliseconds).
|
||||
// For more information see redis documentation for [expiretime] and [pexpiretime].
|
||||
//
|
||||
// [expiretime]: https://redis.io/commands/expiretime/
|
||||
// [pexpiretime]: https://redis.io/commands/pexpiretime/
|
||||
func (m *Miniredis) makeCmdExpireTime(timeResultStrategy func(time.Time) int) server.Cmd {
|
||||
return func(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
c.WriteInt(expiretimeReplyMissingKey)
|
||||
return
|
||||
}
|
||||
|
||||
ttl, ok := db.ttl[key]
|
||||
if !ok {
|
||||
c.WriteInt(expiretimeReplyNoExpiration)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(timeResultStrategy(m.effectiveNow().Add(ttl)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TOUCH
|
||||
func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TTL
|
||||
func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// No such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Seconds()))
|
||||
})
|
||||
}
|
||||
|
||||
// PTTL
|
||||
func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(-2)
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := db.ttl[key]
|
||||
if !ok {
|
||||
// no expire value
|
||||
c.WriteInt(-1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(int(v.Nanoseconds() / 1000000))
|
||||
})
|
||||
}
|
||||
|
||||
// PERSIST
|
||||
func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if _, ok := db.keys[key]; !ok {
|
||||
// no such key
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.ttl[key]; !ok {
|
||||
// no expire value
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
delete(db.ttl, key)
|
||||
db.incr(key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// DEL and UNLINK
|
||||
func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count := 0
|
||||
for _, key := range args {
|
||||
if db.exists(key) {
|
||||
count++
|
||||
}
|
||||
db.del(key, true) // delete expire
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// TYPE
|
||||
func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError("usage error")
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteInline("none")
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInline(t)
|
||||
})
|
||||
}
|
||||
|
||||
// EXISTS
|
||||
func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
found := 0
|
||||
for _, k := range args {
|
||||
if db.exists(k) {
|
||||
found++
|
||||
}
|
||||
}
|
||||
c.WriteInt(found)
|
||||
})
|
||||
}
|
||||
|
||||
// MOVE
|
||||
func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
targetDB int
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
opts.targetDB, _ = strconv.Atoi(args[1])
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if ctx.selectedDB == opts.targetDB {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
db := m.db(ctx.selectedDB)
|
||||
targetDB := m.db(opts.targetDB)
|
||||
|
||||
if !db.move(opts.key, targetDB) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// KEYS
|
||||
func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
keys, _ := matchKeys(db.allKeys(), key)
|
||||
c.WriteLen(len(keys))
|
||||
for _, s := range keys {
|
||||
c.WriteBulk(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RANDOMKEY
|
||||
func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if len(db.keys) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
nr := m.randIntn(len(db.keys))
|
||||
for k := range db.keys {
|
||||
if nr == 0 {
|
||||
c.WriteBulk(k)
|
||||
return
|
||||
}
|
||||
nr--
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RENAME
|
||||
func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// RENAMENX
|
||||
func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
from string
|
||||
to string
|
||||
}{
|
||||
from: args[0],
|
||||
to: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.from) {
|
||||
c.WriteError(msgKeyNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if db.exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
db.rename(opts.from, opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
type scanOpts struct {
|
||||
cursor int
|
||||
count int
|
||||
withMatch bool
|
||||
match string
|
||||
withType bool
|
||||
_type string
|
||||
}
|
||||
|
||||
func scanParse(cmd string, args []string) (*scanOpts, error) {
|
||||
var opts scanOpts
|
||||
if err := optIntSimple(args[0], &opts.cursor); err != nil {
|
||||
return nil, errors.New(msgInvalidCursor)
|
||||
}
|
||||
args = args[1:]
|
||||
|
||||
// MATCH, COUNT and TYPE options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
if len(args) < 2 {
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
count, err := strconv.Atoi(args[1])
|
||||
if err != nil || count < 0 {
|
||||
return nil, errors.New(msgInvalidInt)
|
||||
}
|
||||
if count == 0 {
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
opts.count = count
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match, args = args[1], args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "type" {
|
||||
if len(args) < 2 {
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
opts.withType = true
|
||||
opts._type, args = strings.ToLower(args[1]), args[2:]
|
||||
continue
|
||||
}
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// SCAN
|
||||
func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := scanParse(cmd, args)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// We return _all_ (matched) keys every time.
|
||||
var keys []string
|
||||
|
||||
if opts.withType {
|
||||
keys = make([]string, 0)
|
||||
for k, t := range db.keys {
|
||||
// type must be given exactly; no pattern matching is performed
|
||||
if t == opts._type {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keys = db.allKeys()
|
||||
}
|
||||
|
||||
sort.Strings(keys) // To make things deterministic.
|
||||
|
||||
if opts.withMatch {
|
||||
keys, _ = matchKeys(keys, opts.match)
|
||||
}
|
||||
|
||||
low := opts.cursor
|
||||
high := low + opts.count
|
||||
// validate high is correct
|
||||
if high > len(keys) || high == 0 {
|
||||
high = len(keys)
|
||||
}
|
||||
if opts.cursor > high {
|
||||
// invalid cursor
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
cursorValue := low + opts.count
|
||||
if cursorValue >= len(keys) {
|
||||
cursorValue = 0 // no next cursor
|
||||
}
|
||||
keys = keys[low:high]
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%d", cursorValue))
|
||||
c.WriteLen(len(keys))
|
||||
for _, k := range keys {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type copyOpts struct {
|
||||
from string
|
||||
to string
|
||||
destinationDB int
|
||||
replace bool
|
||||
}
|
||||
|
||||
func copyParse(cmd string, args []string) (*copyOpts, error) {
|
||||
opts := copyOpts{
|
||||
destinationDB: -1,
|
||||
}
|
||||
|
||||
opts.from, opts.to, args = args[0], args[1], args[2:]
|
||||
for len(args) > 0 {
|
||||
switch strings.ToLower(args[0]) {
|
||||
case "db":
|
||||
if len(args) < 2 {
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
if err := optIntSimple(args[1], &opts.destinationDB); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if opts.destinationDB < 0 {
|
||||
return nil, errors.New(msgDBIndexOutOfRange)
|
||||
}
|
||||
args = args[2:]
|
||||
case "replace":
|
||||
opts.replace = true
|
||||
args = args[1:]
|
||||
default:
|
||||
return nil, errors.New(msgSyntaxError)
|
||||
}
|
||||
}
|
||||
return &opts, nil
|
||||
}
|
||||
|
||||
// COPY
|
||||
func (m *Miniredis) cmdCopy(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := copyParse(cmd, args)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
fromDB, toDB := ctx.selectedDB, opts.destinationDB
|
||||
if toDB == -1 {
|
||||
toDB = fromDB
|
||||
}
|
||||
|
||||
if fromDB == toDB && opts.from == opts.to {
|
||||
c.WriteError("ERR source and destination objects are the same")
|
||||
return
|
||||
}
|
||||
|
||||
if !m.db(fromDB).exists(opts.from) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if !opts.replace {
|
||||
if m.db(toDB).exists(opts.to) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.copy(m.db(fromDB), opts.from, m.db(toDB), opts.to)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
+609
@@ -0,0 +1,609 @@
|
||||
// Commands from https://redis.io/commands#geo
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsGeo handles GEOADD, GEORADIUS etc.
|
||||
func commandsGeo(m *Miniredis) {
|
||||
m.srv.Register("GEOADD", m.cmdGeoadd)
|
||||
m.srv.Register("GEODIST", m.cmdGeodist)
|
||||
m.srv.Register("GEOPOS", m.cmdGeopos)
|
||||
m.srv.Register("GEORADIUS", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUS_RO", m.cmdGeoradius)
|
||||
m.srv.Register("GEORADIUSBYMEMBER", m.cmdGeoradiusbymember)
|
||||
m.srv.Register("GEORADIUSBYMEMBER_RO", m.cmdGeoradiusbymember)
|
||||
}
|
||||
|
||||
// GEOADD
|
||||
func (m *Miniredis) cmdGeoadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 || len(args[1:])%3 != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != keyTypeSortedSet {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
toSet := map[string]float64{}
|
||||
for len(args) > 2 {
|
||||
rawLong, rawLat, name := args[0], args[1], args[2]
|
||||
args = args[3:]
|
||||
longitude, err := strconv.ParseFloat(rawLong, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(rawLat, 64)
|
||||
if err != nil {
|
||||
c.WriteError("ERR value is not a valid float")
|
||||
return
|
||||
}
|
||||
|
||||
if latitude < -85.05112878 ||
|
||||
latitude > 85.05112878 ||
|
||||
longitude < -180 ||
|
||||
longitude > 180 {
|
||||
c.WriteError(fmt.Sprintf("ERR invalid longitude,latitude pair %.6f,%.6f", longitude, latitude))
|
||||
return
|
||||
}
|
||||
|
||||
toSet[name] = float64(toGeohash(longitude, latitude))
|
||||
}
|
||||
|
||||
set := 0
|
||||
for name, score := range toSet {
|
||||
if db.ssetAdd(key, score, name) {
|
||||
set++
|
||||
}
|
||||
}
|
||||
c.WriteInt(set)
|
||||
})
|
||||
}
|
||||
|
||||
// GEODIST
|
||||
func (m *Miniredis) cmdGeodist(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, from, to, args := args[0], args[1], args[2], args[3:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if db.t(key) != keyTypeSortedSet {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
unit := "m"
|
||||
if len(args) > 0 {
|
||||
unit, args = args[0], args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
toMeter := parseUnit(unit)
|
||||
if toMeter == 0 {
|
||||
c.WriteError(msgUnsupportedUnit)
|
||||
return
|
||||
}
|
||||
|
||||
members := db.sortedsetKeys[key]
|
||||
fromD, okFrom := members.get(from)
|
||||
toD, okTo := members.get(to)
|
||||
if !okFrom || !okTo {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
fromLo, fromLat := fromGeohash(uint64(fromD))
|
||||
toLo, toLat := fromGeohash(uint64(toD))
|
||||
|
||||
dist := distance(fromLat, fromLo, toLat, toLo) / toMeter
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", dist))
|
||||
})
|
||||
}
|
||||
|
||||
// GEOPOS
|
||||
func (m *Miniredis) cmdGeopos(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
key, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != keyTypeSortedSet {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(args))
|
||||
for _, l := range args {
|
||||
if !db.ssetExists(key, l) {
|
||||
c.WriteLen(-1)
|
||||
continue
|
||||
}
|
||||
score := db.ssetScore(key, l)
|
||||
c.WriteLen(2)
|
||||
long, lat := fromGeohash(uint64(score))
|
||||
c.WriteBulk(fmt.Sprintf("%f", long))
|
||||
c.WriteBulk(fmt.Sprintf("%f", lat))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type geoDistance struct {
|
||||
Name string
|
||||
Score float64
|
||||
Distance float64
|
||||
Longitude float64
|
||||
Latitude float64
|
||||
}
|
||||
|
||||
// GEORADIUS and GEORADIUS_RO
|
||||
func (m *Miniredis) cmdGeoradius(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 5 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
longitude, err := strconv.ParseFloat(args[1], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
latitude, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
radius, err := strconv.ParseFloat(args[3], 64)
|
||||
if err != nil || radius < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
toMeter := parseUnit(args[4])
|
||||
if toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[5:]
|
||||
|
||||
var opts struct {
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUS_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
members := db.ssetElements(key)
|
||||
|
||||
matches := withinRadius(members, longitude, latitude, radius*toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GEORADIUSBYMEMBER and GEORADIUSBYMEMBER_RO
|
||||
func (m *Miniredis) cmdGeoradiusbymember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 4 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
member string
|
||||
radius float64
|
||||
toMeter float64
|
||||
|
||||
withDist bool
|
||||
withCoord bool
|
||||
direction direction // unsorted
|
||||
count int
|
||||
withStore bool
|
||||
storeKey string
|
||||
withStoredist bool
|
||||
storedistKey string
|
||||
}{
|
||||
key: args[0],
|
||||
member: args[1],
|
||||
}
|
||||
|
||||
r, err := strconv.ParseFloat(args[2], 64)
|
||||
if err != nil || r < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
opts.radius = r
|
||||
|
||||
opts.toMeter = parseUnit(args[3])
|
||||
if opts.toMeter == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
args = args[4:]
|
||||
|
||||
for len(args) > 0 {
|
||||
arg := args[0]
|
||||
args = args[1:]
|
||||
switch strings.ToUpper(arg) {
|
||||
case "WITHCOORD":
|
||||
opts.withCoord = true
|
||||
case "WITHDIST":
|
||||
opts.withDist = true
|
||||
case "ASC":
|
||||
opts.direction = asc
|
||||
case "DESC":
|
||||
opts.direction = desc
|
||||
case "COUNT":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
n, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if n <= 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR COUNT must be > 0")
|
||||
return
|
||||
}
|
||||
args = args[1:]
|
||||
opts.count = n
|
||||
case "STORE":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStore = true
|
||||
opts.storeKey = args[0]
|
||||
args = args[1:]
|
||||
case "STOREDIST":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
opts.withStoredist = true
|
||||
opts.storedistKey = args[0]
|
||||
args = args[1:]
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.ToUpper(cmd) == "GEORADIUSBYMEMBER_RO" && (opts.withStore || opts.withStoredist) {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR syntax error")
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
if (opts.withStore || opts.withStoredist) && (opts.withDist || opts.withCoord) {
|
||||
c.WriteError("ERR STORE option in GEORADIUS is not compatible with WITHDIST, WITHHASH and WITHCOORDS options")
|
||||
return
|
||||
}
|
||||
|
||||
db := m.db(ctx.selectedDB)
|
||||
if !db.exists(opts.key) {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(opts.key) != keyTypeSortedSet {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// get position of member
|
||||
if !db.ssetExists(opts.key, opts.member) {
|
||||
c.WriteError("ERR could not decode requested zset member")
|
||||
return
|
||||
}
|
||||
score := db.ssetScore(opts.key, opts.member)
|
||||
longitude, latitude := fromGeohash(uint64(score))
|
||||
|
||||
members := db.ssetElements(opts.key)
|
||||
matches := withinRadius(members, longitude, latitude, opts.radius*opts.toMeter)
|
||||
|
||||
// deal with ASC/DESC
|
||||
if opts.direction != unsorted {
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if opts.direction == desc {
|
||||
return matches[i].Distance > matches[j].Distance
|
||||
}
|
||||
return matches[i].Distance < matches[j].Distance
|
||||
})
|
||||
}
|
||||
|
||||
// deal with COUNT
|
||||
if opts.count > 0 && len(matches) > opts.count {
|
||||
matches = matches[:opts.count]
|
||||
}
|
||||
|
||||
// deal with "STORE x"
|
||||
if opts.withStore {
|
||||
db.del(opts.storeKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storeKey, member.Score, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
// deal with "STOREDIST x"
|
||||
if opts.withStoredist {
|
||||
db.del(opts.storedistKey, true)
|
||||
for _, member := range matches {
|
||||
db.ssetAdd(opts.storedistKey, member.Distance/opts.toMeter, member.Name)
|
||||
}
|
||||
c.WriteInt(len(matches))
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(matches))
|
||||
for _, member := range matches {
|
||||
if !opts.withDist && !opts.withCoord {
|
||||
c.WriteBulk(member.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
len := 1
|
||||
if opts.withDist {
|
||||
len++
|
||||
}
|
||||
if opts.withCoord {
|
||||
len++
|
||||
}
|
||||
c.WriteLen(len)
|
||||
c.WriteBulk(member.Name)
|
||||
if opts.withDist {
|
||||
c.WriteBulk(fmt.Sprintf("%.4f", member.Distance/opts.toMeter))
|
||||
}
|
||||
if opts.withCoord {
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Longitude))
|
||||
c.WriteBulk(fmt.Sprintf("%f", member.Latitude))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func withinRadius(members []ssElem, longitude, latitude, radius float64) []geoDistance {
|
||||
matches := []geoDistance{}
|
||||
for _, el := range members {
|
||||
elLo, elLat := fromGeohash(uint64(el.score))
|
||||
distanceInMeter := distance(latitude, longitude, elLat, elLo)
|
||||
|
||||
if distanceInMeter <= radius {
|
||||
matches = append(matches, geoDistance{
|
||||
Name: el.member,
|
||||
Score: el.score,
|
||||
Distance: distanceInMeter,
|
||||
Longitude: elLo,
|
||||
Latitude: elLat,
|
||||
})
|
||||
}
|
||||
}
|
||||
return matches
|
||||
}
|
||||
|
||||
func parseUnit(u string) float64 {
|
||||
switch strings.ToLower(u) {
|
||||
case "m":
|
||||
return 1
|
||||
case "km":
|
||||
return 1000
|
||||
case "mi":
|
||||
return 1609.34
|
||||
case "ft":
|
||||
return 0.3048
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
+777
@@ -0,0 +1,777 @@
|
||||
// Commands from https://redis.io/commands#hash
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsHash handles all hash value operations.
|
||||
func commandsHash(m *Miniredis) {
|
||||
m.srv.Register("HDEL", m.cmdHdel)
|
||||
m.srv.Register("HEXISTS", m.cmdHexists)
|
||||
m.srv.Register("HGET", m.cmdHget)
|
||||
m.srv.Register("HGETALL", m.cmdHgetall)
|
||||
m.srv.Register("HINCRBY", m.cmdHincrby)
|
||||
m.srv.Register("HINCRBYFLOAT", m.cmdHincrbyfloat)
|
||||
m.srv.Register("HKEYS", m.cmdHkeys)
|
||||
m.srv.Register("HLEN", m.cmdHlen)
|
||||
m.srv.Register("HMGET", m.cmdHmget)
|
||||
m.srv.Register("HMSET", m.cmdHmset)
|
||||
m.srv.Register("HSET", m.cmdHset)
|
||||
m.srv.Register("HSETNX", m.cmdHsetnx)
|
||||
m.srv.Register("HSTRLEN", m.cmdHstrlen)
|
||||
m.srv.Register("HVALS", m.cmdHvals)
|
||||
m.srv.Register("HSCAN", m.cmdHscan)
|
||||
m.srv.Register("HRANDFIELD", m.cmdHrandfield)
|
||||
}
|
||||
|
||||
// HSET
|
||||
func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, pairs := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if len(pairs)%2 == 1 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
new := db.hashSet(key, pairs...)
|
||||
c.WriteInt(new)
|
||||
})
|
||||
}
|
||||
|
||||
// HSETNX
|
||||
func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
value string
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
value: args[2],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.hashKeys[opts.key]; !ok {
|
||||
db.hashKeys[opts.key] = map[string]string{}
|
||||
db.keys[opts.key] = keyTypeHash
|
||||
}
|
||||
_, ok := db.hashKeys[opts.key][opts.field]
|
||||
if ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
db.hashKeys[opts.key][opts.field] = opts.value
|
||||
db.incr(opts.key)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// HMSET
|
||||
func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, args := args[0], args[1:]
|
||||
if len(args)%2 != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
for len(args) > 0 {
|
||||
field, value := args[0], args[1]
|
||||
args = args[2:]
|
||||
db.hashSet(key, field, value)
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// HGET
|
||||
func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, field := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
value, ok := db.hashKeys[key][field]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteBulk(value)
|
||||
})
|
||||
}
|
||||
|
||||
// HDEL
|
||||
func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
fields []string
|
||||
}{
|
||||
key: args[0],
|
||||
fields: args[1:],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
// No key is zero deleted
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
deleted := 0
|
||||
for _, f := range opts.fields {
|
||||
_, ok := db.hashKeys[opts.key][f]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
delete(db.hashKeys[opts.key], f)
|
||||
deleted++
|
||||
}
|
||||
c.WriteInt(deleted)
|
||||
|
||||
// Nothing left. Remove the whole key.
|
||||
if len(db.hashKeys[opts.key]) == 0 {
|
||||
db.del(opts.key, true)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HEXISTS
|
||||
func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[opts.key]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := db.hashKeys[opts.key][opts.field]; !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// HGETALL
|
||||
func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteMapLen(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteMapLen(len(db.hashKeys[key]))
|
||||
for _, k := range db.hashFields(key) {
|
||||
c.WriteBulk(k)
|
||||
c.WriteBulk(db.hashGet(key, k))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HKEYS
|
||||
func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
if db.t(key) != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
fields := db.hashFields(key)
|
||||
c.WriteLen(len(fields))
|
||||
for _, f := range fields {
|
||||
c.WriteBulk(f)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HSTRLEN
|
||||
func (m *Miniredis) cmdHstrlen(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
hash, key := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[hash]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
keys := db.hashKeys[hash]
|
||||
c.WriteInt(len(keys[key]))
|
||||
})
|
||||
}
|
||||
|
||||
// HVALS
|
||||
func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
vals := db.hashValues(key)
|
||||
c.WriteLen(len(vals))
|
||||
for _, v := range vals {
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HLEN
|
||||
func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
if t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(len(db.hashKeys[key]))
|
||||
})
|
||||
}
|
||||
|
||||
// HMGET
|
||||
func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
f, ok := db.hashKeys[key]
|
||||
if !ok {
|
||||
f = map[string]string{}
|
||||
}
|
||||
|
||||
c.WriteLen(len(args) - 1)
|
||||
for _, k := range args[1:] {
|
||||
v, ok := f[k]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
continue
|
||||
}
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HINCRBY
|
||||
func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
delta int
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
if ok := optInt(c, args[2], &opts.delta); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := db.hashIncr(opts.key, opts.field, opts.delta)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteInt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// HINCRBYFLOAT
|
||||
func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
field string
|
||||
delta *big.Float
|
||||
}{
|
||||
key: args[0],
|
||||
field: args[1],
|
||||
}
|
||||
delta, _, err := big.ParseFloat(args[2], 10, 128, 0)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidFloat)
|
||||
return
|
||||
}
|
||||
opts.delta = delta
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if t, ok := db.keys[opts.key]; ok && t != keyTypeHash {
|
||||
c.WriteError(msgWrongType)
|
||||
return
|
||||
}
|
||||
|
||||
v, err := db.hashIncrfloat(opts.key, opts.field, opts.delta)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteBulk(formatBig(v))
|
||||
})
|
||||
}
|
||||
|
||||
// HSCAN
|
||||
func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
cursor int
|
||||
withMatch bool
|
||||
match string
|
||||
}{
|
||||
key: args[0],
|
||||
}
|
||||
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
|
||||
// MATCH and COUNT options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
// we do nothing with count
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
_, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match, args = args[1], args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// return _all_ (matched) keys every time
|
||||
|
||||
if opts.cursor != 0 {
|
||||
// Invalid cursor.
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
if db.exists(opts.key) && db.t(opts.key) != keyTypeHash {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.hashFields(opts.key)
|
||||
if opts.withMatch {
|
||||
members, _ = matchKeys(members, opts.match)
|
||||
}
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
// HSCAN gives key, values.
|
||||
c.WriteLen(len(members) * 2)
|
||||
for _, k := range members {
|
||||
c.WriteBulk(k)
|
||||
c.WriteBulk(db.hashGet(opts.key, k))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// HRANDFIELD
|
||||
func (m *Miniredis) cmdHrandfield(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) > 3 || len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
count int
|
||||
countSet bool
|
||||
withValues bool
|
||||
}{
|
||||
key: args[0],
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
if ok := optIntErr(c, args[1], &opts.count, msgInvalidInt); !ok {
|
||||
return
|
||||
}
|
||||
opts.countSet = true
|
||||
}
|
||||
|
||||
if len(args) == 3 {
|
||||
if strings.ToLower(args[2]) == "withvalues" {
|
||||
opts.withValues = true
|
||||
} else {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
withTx(m, c, func(peer *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
members := db.hashFields(opts.key)
|
||||
m.shuffle(members)
|
||||
|
||||
if !opts.countSet {
|
||||
// > When called with just the key argument, return a random field from the
|
||||
// hash value stored at key.
|
||||
if len(members) == 0 {
|
||||
peer.WriteNull()
|
||||
return
|
||||
}
|
||||
peer.WriteBulk(members[0])
|
||||
return
|
||||
}
|
||||
|
||||
if len(members) > abs(opts.count) {
|
||||
members = members[:abs(opts.count)]
|
||||
}
|
||||
switch {
|
||||
case opts.count >= 0:
|
||||
// if count is positive there can't be duplicates, and the length is restricted
|
||||
case opts.count < 0:
|
||||
// if count is negative there can be duplicates, but length will match
|
||||
if len(members) > 0 {
|
||||
for len(members) < -opts.count {
|
||||
members = append(members, members[m.randIntn(len(members))])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.withValues {
|
||||
peer.WriteMapLen(len(members))
|
||||
for _, m := range members {
|
||||
peer.WriteBulk(m)
|
||||
peer.WriteBulk(db.hashGet(opts.key, m))
|
||||
}
|
||||
return
|
||||
}
|
||||
peer.WriteLen(len(members))
|
||||
for _, m := range members {
|
||||
peer.WriteBulk(m)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func abs(n int) int {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
package miniredis
|
||||
|
||||
import "github.com/alicebob/miniredis/v2/server"
|
||||
|
||||
// commandsHll handles all hll related operations.
|
||||
func commandsHll(m *Miniredis) {
|
||||
m.srv.Register("PFADD", m.cmdPfadd)
|
||||
m.srv.Register("PFCOUNT", m.cmdPfcount)
|
||||
m.srv.Register("PFMERGE", m.cmdPfmerge)
|
||||
}
|
||||
|
||||
// PFADD
|
||||
func (m *Miniredis) cmdPfadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, items := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != keyTypeHll {
|
||||
c.WriteError(ErrNotValidHllValue.Error())
|
||||
return
|
||||
}
|
||||
|
||||
altered := db.hllAdd(key, items...)
|
||||
c.WriteInt(altered)
|
||||
})
|
||||
}
|
||||
|
||||
// PFCOUNT
|
||||
func (m *Miniredis) cmdPfcount(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count, err := db.hllCount(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// PFMERGE
|
||||
func (m *Miniredis) cmdPfmerge(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if err := db.hllMerge(keys); err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// Command 'INFO' from https://redis.io/commands/info/
|
||||
func (m *Miniredis) cmdInfo(c *server.Peer, cmd string, args []string) {
|
||||
if !m.isValidCMD(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
const (
|
||||
clientsSectionName = "clients"
|
||||
clientsSectionContent = "# Clients\nconnected_clients:%d\r\n"
|
||||
)
|
||||
|
||||
var result string
|
||||
|
||||
for _, key := range args {
|
||||
if key != clientsSectionName {
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf("section (%s) is not supported", key))
|
||||
return
|
||||
}
|
||||
}
|
||||
result = fmt.Sprintf(clientsSectionContent, m.Server().ClientsLen())
|
||||
|
||||
c.WriteBulk(result)
|
||||
})
|
||||
}
|
||||
+1060
File diff suppressed because it is too large
Load Diff
+58
@@ -0,0 +1,58 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsObject handles all object operations.
|
||||
func commandsObject(m *Miniredis) {
|
||||
m.srv.Register("OBJECT", m.cmdObject)
|
||||
}
|
||||
|
||||
// OBJECT
|
||||
func (m *Miniredis) cmdObject(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
switch sub := strings.ToLower(args[0]); sub {
|
||||
case "idletime":
|
||||
m.cmdObjectIdletime(c, args[1:])
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf(msgFObjectUsage, sub))
|
||||
}
|
||||
}
|
||||
|
||||
// OBJECT IDLETIME
|
||||
func (m *Miniredis) cmdObjectIdletime(c *server.Peer, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber("object|idletime"))
|
||||
return
|
||||
}
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
t, ok := db.lru[key]
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(int(db.master.effectiveNow().Sub(t).Seconds()))
|
||||
})
|
||||
}
|
||||
+262
@@ -0,0 +1,262 @@
|
||||
// Commands from https://redis.io/commands#pubsub
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsPubsub handles all PUB/SUB operations.
|
||||
func commandsPubsub(m *Miniredis) {
|
||||
m.srv.Register("SUBSCRIBE", m.cmdSubscribe)
|
||||
m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe)
|
||||
m.srv.Register("PSUBSCRIBE", m.cmdPsubscribe)
|
||||
m.srv.Register("PUNSUBSCRIBE", m.cmdPunsubscribe)
|
||||
m.srv.Register("PUBLISH", m.cmdPublish)
|
||||
m.srv.Register("PUBSUB", m.cmdPubSub)
|
||||
}
|
||||
|
||||
// SUBSCRIBE
|
||||
func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
sub := m.subscribedState(c)
|
||||
for _, channel := range args {
|
||||
n := sub.Subscribe(channel)
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("subscribe")
|
||||
w.WriteBulk(channel)
|
||||
w.WriteInt(n)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// UNSUBSCRIBE
|
||||
func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
channels := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
sub := m.subscribedState(c)
|
||||
|
||||
if len(channels) == 0 {
|
||||
channels = sub.Channels()
|
||||
}
|
||||
|
||||
// there is no de-duplication
|
||||
for _, channel := range channels {
|
||||
n := sub.Unsubscribe(channel)
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("unsubscribe")
|
||||
w.WriteBulk(channel)
|
||||
w.WriteInt(n)
|
||||
})
|
||||
}
|
||||
if len(channels) == 0 {
|
||||
// special case: there is always a reply
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("unsubscribe")
|
||||
w.WriteNull()
|
||||
w.WriteInt(0)
|
||||
})
|
||||
}
|
||||
|
||||
if sub.Count() == 0 {
|
||||
endSubscriber(m, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PSUBSCRIBE
|
||||
func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
sub := m.subscribedState(c)
|
||||
for _, pat := range args {
|
||||
n := sub.Psubscribe(pat)
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("psubscribe")
|
||||
w.WriteBulk(pat)
|
||||
w.WriteInt(n)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PUNSUBSCRIBE
|
||||
func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) {
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
patterns := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
sub := m.subscribedState(c)
|
||||
|
||||
if len(patterns) == 0 {
|
||||
patterns = sub.Patterns()
|
||||
}
|
||||
|
||||
// there is no de-duplication
|
||||
for _, pat := range patterns {
|
||||
n := sub.Punsubscribe(pat)
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("punsubscribe")
|
||||
w.WriteBulk(pat)
|
||||
w.WriteInt(n)
|
||||
})
|
||||
}
|
||||
if len(patterns) == 0 {
|
||||
// special case: there is always a reply
|
||||
c.Block(func(w *server.Writer) {
|
||||
w.WritePushLen(3)
|
||||
w.WriteBulk("punsubscribe")
|
||||
w.WriteNull()
|
||||
w.WriteInt(0)
|
||||
})
|
||||
}
|
||||
|
||||
if sub.Count() == 0 {
|
||||
endSubscriber(m, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PUBLISH
|
||||
func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
channel, mesg := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
c.WriteInt(m.publish(channel, mesg))
|
||||
})
|
||||
}
|
||||
|
||||
// PUBSUB
|
||||
func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
subcommand := strings.ToUpper(args[0])
|
||||
subargs := args[1:]
|
||||
var argsOk bool
|
||||
|
||||
switch subcommand {
|
||||
case "CHANNELS":
|
||||
argsOk = len(subargs) < 2
|
||||
case "NUMSUB":
|
||||
argsOk = true
|
||||
case "NUMPAT":
|
||||
argsOk = len(subargs) == 0
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf(msgFPubsubUsageSimple, subcommand))
|
||||
return
|
||||
}
|
||||
|
||||
if !argsOk {
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf(msgFPubsubUsage, subcommand))
|
||||
return
|
||||
}
|
||||
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
switch subcommand {
|
||||
case "CHANNELS":
|
||||
pat := ""
|
||||
if len(subargs) == 1 {
|
||||
pat = subargs[0]
|
||||
}
|
||||
|
||||
allsubs := m.allSubscribers()
|
||||
channels := activeChannels(allsubs, pat)
|
||||
|
||||
c.WriteLen(len(channels))
|
||||
for _, channel := range channels {
|
||||
c.WriteBulk(channel)
|
||||
}
|
||||
|
||||
case "NUMSUB":
|
||||
subs := m.allSubscribers()
|
||||
c.WriteLen(len(subargs) * 2)
|
||||
for _, channel := range subargs {
|
||||
c.WriteBulk(channel)
|
||||
c.WriteInt(countSubs(subs, channel))
|
||||
}
|
||||
|
||||
case "NUMPAT":
|
||||
c.WriteInt(countPsubs(m.allSubscribers()))
|
||||
}
|
||||
})
|
||||
}
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
"github.com/yuin/gopher-lua/parse"
|
||||
|
||||
luajson "github.com/alicebob/miniredis/v2/gopher-json"
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
func commandsScripting(m *Miniredis) {
|
||||
m.srv.Register("EVAL", m.cmdEval)
|
||||
m.srv.Register("EVALSHA", m.cmdEvalsha)
|
||||
m.srv.Register("SCRIPT", m.cmdScript)
|
||||
}
|
||||
|
||||
var (
|
||||
parsedScripts = sync.Map{}
|
||||
)
|
||||
|
||||
// Execute lua. Needs to run m.Lock()ed, from within withTx().
|
||||
// Returns true if the lua was OK (and hence should be cached).
|
||||
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
|
||||
l := lua.NewState(lua.Options{SkipOpenLibs: true})
|
||||
defer l.Close()
|
||||
|
||||
// Taken from the go-lua manual
|
||||
for _, pair := range []struct {
|
||||
n string
|
||||
f lua.LGFunction
|
||||
}{
|
||||
{lua.LoadLibName, lua.OpenPackage},
|
||||
{lua.BaseLibName, lua.OpenBase},
|
||||
{lua.CoroutineLibName, lua.OpenCoroutine},
|
||||
{lua.TabLibName, lua.OpenTable},
|
||||
{lua.StringLibName, lua.OpenString},
|
||||
{lua.MathLibName, lua.OpenMath},
|
||||
{lua.DebugLibName, lua.OpenDebug},
|
||||
} {
|
||||
if err := l.CallByParam(lua.P{
|
||||
Fn: l.NewFunction(pair.f),
|
||||
NRet: 0,
|
||||
Protect: true,
|
||||
}, lua.LString(pair.n)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
luajson.Preload(l)
|
||||
requireGlobal(l, "cjson", "json")
|
||||
|
||||
// set global variable KEYS
|
||||
keysTable := l.NewTable()
|
||||
keysS, args := args[0], args[1:]
|
||||
keysLen, err := strconv.Atoi(keysS)
|
||||
if err != nil {
|
||||
c.WriteError(msgInvalidInt)
|
||||
return false
|
||||
}
|
||||
if keysLen < 0 {
|
||||
c.WriteError(msgNegativeKeysNumber)
|
||||
return false
|
||||
}
|
||||
if keysLen > len(args) {
|
||||
c.WriteError(msgInvalidKeysNumber)
|
||||
return false
|
||||
}
|
||||
keys, args := args[:keysLen], args[keysLen:]
|
||||
for i, k := range keys {
|
||||
l.RawSet(keysTable, lua.LNumber(i+1), lua.LString(k))
|
||||
}
|
||||
l.SetGlobal("KEYS", keysTable)
|
||||
|
||||
argvTable := l.NewTable()
|
||||
for i, a := range args {
|
||||
l.RawSet(argvTable, lua.LNumber(i+1), lua.LString(a))
|
||||
}
|
||||
l.SetGlobal("ARGV", argvTable)
|
||||
|
||||
redisFuncs, redisConstants := mkLua(m.srv, c, sha)
|
||||
// Register command handlers
|
||||
l.Push(l.NewFunction(func(l *lua.LState) int {
|
||||
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
|
||||
for k, v := range redisConstants {
|
||||
mod.RawSetString(k, v)
|
||||
}
|
||||
l.Push(mod)
|
||||
return 1
|
||||
}))
|
||||
|
||||
_ = doScript(l, protectGlobals)
|
||||
|
||||
l.Push(lua.LString("redis"))
|
||||
l.Call(1, 0)
|
||||
|
||||
// lua can call redis.setresp(...), but it's tmp state.
|
||||
oldresp := c.Resp3
|
||||
if err := doScript(l, script); err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
luaToRedis(l, c, l.Get(1))
|
||||
c.Resp3 = oldresp
|
||||
c.SwitchResp3 = nil
|
||||
return true
|
||||
}
|
||||
|
||||
// doScript pre-compiles the given script into a Lua prototype,
|
||||
// then executes the pre-compiled function against the given lua state.
|
||||
//
|
||||
// This is thread-safe.
|
||||
func doScript(l *lua.LState, script string) error {
|
||||
proto, err := compile(script)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errLuaParseError(err))
|
||||
}
|
||||
|
||||
lfunc := l.NewFunctionFromProto(proto)
|
||||
l.Push(lfunc)
|
||||
if err := l.PCall(0, lua.MultRet, nil); err != nil {
|
||||
// ensure we wrap with the correct format.
|
||||
return fmt.Errorf(errLuaParseError(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func compile(script string) (*lua.FunctionProto, error) {
|
||||
if val, ok := parsedScripts.Load(script); ok {
|
||||
return val.(*lua.FunctionProto), nil
|
||||
}
|
||||
chunk, err := parse.Parse(strings.NewReader(script), "<string>")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proto, err := lua.Compile(chunk, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedScripts.Store(script, proto)
|
||||
return proto, nil
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
script, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
sha := sha1Hex(script)
|
||||
ok := m.runLuaScript(c, sha, script, args)
|
||||
if ok {
|
||||
m.scripts[sha] = script
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
sha, args := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
script, ok := m.scripts[sha]
|
||||
if !ok {
|
||||
c.WriteError(msgNoScriptFound)
|
||||
return
|
||||
}
|
||||
|
||||
m.runLuaScript(c, sha, script, args)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
subcmd string
|
||||
script string
|
||||
}
|
||||
|
||||
opts.subcmd, args = args[0], args[1:]
|
||||
|
||||
switch strings.ToLower(opts.subcmd) {
|
||||
case "load":
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf(msgFScriptUsage, "LOAD"))
|
||||
return
|
||||
}
|
||||
opts.script = args[0]
|
||||
case "exists":
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber("script|exists"))
|
||||
return
|
||||
}
|
||||
case "flush":
|
||||
if len(args) == 1 {
|
||||
switch strings.ToUpper(args[0]) {
|
||||
case "SYNC", "ASYNC":
|
||||
args = args[1:]
|
||||
default:
|
||||
}
|
||||
}
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgScriptFlush)
|
||||
return
|
||||
}
|
||||
default:
|
||||
setDirty(c)
|
||||
c.WriteError(fmt.Sprintf(msgFScriptUsageSimple, strings.ToUpper(opts.subcmd)))
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
switch strings.ToLower(opts.subcmd) {
|
||||
case "load":
|
||||
if _, err := parse.Parse(strings.NewReader(opts.script), "user_script"); err != nil {
|
||||
c.WriteError(errLuaParseError(err))
|
||||
return
|
||||
}
|
||||
sha := sha1Hex(opts.script)
|
||||
m.scripts[sha] = opts.script
|
||||
c.WriteBulk(sha)
|
||||
case "exists":
|
||||
c.WriteLen(len(args))
|
||||
for _, arg := range args {
|
||||
if _, ok := m.scripts[arg]; ok {
|
||||
c.WriteInt(1)
|
||||
} else {
|
||||
c.WriteInt(0)
|
||||
}
|
||||
}
|
||||
case "flush":
|
||||
m.scripts = map[string]string{}
|
||||
c.WriteOK()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func sha1Hex(s string) string {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, s)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// requireGlobal imports module modName into the global namespace with the
|
||||
// identifier id. panics if an error results from the function execution
|
||||
func requireGlobal(l *lua.LState, id, modName string) {
|
||||
if err := l.CallByParam(lua.P{
|
||||
Fn: l.GetGlobal("require"),
|
||||
NRet: 1,
|
||||
Protect: true,
|
||||
}, lua.LString(modName)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mod := l.Get(-1)
|
||||
l.Pop(1)
|
||||
|
||||
l.SetGlobal(id, mod)
|
||||
}
|
||||
|
||||
// the following script protects globals
|
||||
// it is based on: http://metalua.luaforge.net/src/lib/strict.lua.html
|
||||
var protectGlobals = `
|
||||
local dbg=debug
|
||||
local mt = {}
|
||||
setmetatable(_G, mt)
|
||||
mt.__newindex = function (t, n, v)
|
||||
if dbg.getinfo(2) then
|
||||
local w = dbg.getinfo(2, "S").what
|
||||
if w ~= "C" then
|
||||
error("Script attempted to create global variable '"..tostring(n).."'", 2)
|
||||
end
|
||||
end
|
||||
rawset(t, n, v)
|
||||
end
|
||||
mt.__index = function (t, n)
|
||||
if dbg.getinfo(2) and dbg.getinfo(2, "S").what ~= "C" then
|
||||
error("Script attempted to access nonexistent global variable '"..tostring(n).."'", 2)
|
||||
end
|
||||
return rawget(t, n)
|
||||
end
|
||||
debug = nil
|
||||
|
||||
`
|
||||
+177
@@ -0,0 +1,177 @@
|
||||
// Commands from https://redis.io/commands#server
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
"github.com/alicebob/miniredis/v2/size"
|
||||
)
|
||||
|
||||
func commandsServer(m *Miniredis) {
|
||||
m.srv.Register("COMMAND", m.cmdCommand)
|
||||
m.srv.Register("DBSIZE", m.cmdDbsize)
|
||||
m.srv.Register("FLUSHALL", m.cmdFlushall)
|
||||
m.srv.Register("FLUSHDB", m.cmdFlushdb)
|
||||
m.srv.Register("INFO", m.cmdInfo)
|
||||
m.srv.Register("TIME", m.cmdTime)
|
||||
m.srv.Register("MEMORY", m.cmdMemory)
|
||||
}
|
||||
|
||||
// MEMORY
|
||||
func (m *Miniredis) cmdMemory(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
cmd, args := strings.ToLower(args[0]), args[1:]
|
||||
switch cmd {
|
||||
case "usage":
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber("memory|usage"))
|
||||
return
|
||||
}
|
||||
if len(args) > 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
value interface{}
|
||||
ok bool
|
||||
)
|
||||
switch db.keys[args[0]] {
|
||||
case keyTypeString:
|
||||
value, ok = db.stringKeys[args[0]]
|
||||
case keyTypeSet:
|
||||
value, ok = db.setKeys[args[0]]
|
||||
case keyTypeHash:
|
||||
value, ok = db.hashKeys[args[0]]
|
||||
case keyTypeList:
|
||||
value, ok = db.listKeys[args[0]]
|
||||
case keyTypeHll:
|
||||
value, ok = db.hllKeys[args[0]]
|
||||
case keyTypeSortedSet:
|
||||
value, ok = db.sortedsetKeys[args[0]]
|
||||
case keyTypeStream:
|
||||
value, ok = db.streamKeys[args[0]]
|
||||
}
|
||||
if !ok {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteInt(size.Of(value))
|
||||
default:
|
||||
c.WriteError(fmt.Sprintf(msgMemorySubcommand, strings.ToUpper(cmd)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// DBSIZE
|
||||
func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
c.WriteInt(len(db.keys))
|
||||
})
|
||||
}
|
||||
|
||||
// FLUSHALL
|
||||
func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) > 0 && strings.ToLower(args[0]) == "async" {
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
m.flushAll()
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// FLUSHDB
|
||||
func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) > 0 && strings.ToLower(args[0]) == "async" {
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
m.db(ctx.selectedDB).flush()
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
|
||||
// TIME
|
||||
func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
now := m.effectiveNow()
|
||||
nanos := now.UnixNano()
|
||||
seconds := nanos / 1_000_000_000
|
||||
microseconds := (nanos / 1_000) % 1_000_000
|
||||
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(strconv.FormatInt(seconds, 10))
|
||||
c.WriteBulk(strconv.FormatInt(microseconds, 10))
|
||||
})
|
||||
}
|
||||
+836
@@ -0,0 +1,836 @@
|
||||
// Commands from https://redis.io/commands#set
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsSet handles all set value operations.
|
||||
func commandsSet(m *Miniredis) {
|
||||
m.srv.Register("SADD", m.cmdSadd)
|
||||
m.srv.Register("SCARD", m.cmdScard)
|
||||
m.srv.Register("SDIFF", m.cmdSdiff)
|
||||
m.srv.Register("SDIFFSTORE", m.cmdSdiffstore)
|
||||
m.srv.Register("SINTERCARD", m.cmdSintercard)
|
||||
m.srv.Register("SINTER", m.cmdSinter)
|
||||
m.srv.Register("SINTERSTORE", m.cmdSinterstore)
|
||||
m.srv.Register("SISMEMBER", m.cmdSismember)
|
||||
m.srv.Register("SMEMBERS", m.cmdSmembers)
|
||||
m.srv.Register("SMISMEMBER", m.cmdSmismember)
|
||||
m.srv.Register("SMOVE", m.cmdSmove)
|
||||
m.srv.Register("SPOP", m.cmdSpop)
|
||||
m.srv.Register("SRANDMEMBER", m.cmdSrandmember)
|
||||
m.srv.Register("SREM", m.cmdSrem)
|
||||
m.srv.Register("SUNION", m.cmdSunion)
|
||||
m.srv.Register("SUNIONSTORE", m.cmdSunionstore)
|
||||
m.srv.Register("SSCAN", m.cmdSscan)
|
||||
}
|
||||
|
||||
// SADD
|
||||
func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, elems := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if db.exists(key) && db.t(key) != keyTypeSet {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
added := db.setAdd(key, elems...)
|
||||
c.WriteInt(added)
|
||||
})
|
||||
}
|
||||
|
||||
// SCARD
|
||||
func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
c.WriteInt(len(members))
|
||||
})
|
||||
}
|
||||
|
||||
// SDIFF
|
||||
func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setDiff(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteSetLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SDIFFSTORE
|
||||
func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setDiff(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SINTER
|
||||
func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setInter(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SINTERSTORE
|
||||
func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setInter(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SINTERCARD
|
||||
func (m *Miniredis) cmdSintercard(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
keys []string
|
||||
limit int
|
||||
}{}
|
||||
|
||||
numKeys, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR numkeys should be greater than 0")
|
||||
return
|
||||
}
|
||||
if numKeys < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR numkeys should be greater than 0")
|
||||
return
|
||||
}
|
||||
|
||||
args = args[1:]
|
||||
if len(args) < numKeys {
|
||||
setDirty(c)
|
||||
c.WriteError("ERR Number of keys can't be greater than number of args")
|
||||
return
|
||||
}
|
||||
opts.keys = args[:numKeys]
|
||||
|
||||
args = args[numKeys:]
|
||||
if len(args) == 2 && strings.ToLower(args[0]) == "limit" {
|
||||
l, err := strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if l < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgLimitIsNegative)
|
||||
return
|
||||
}
|
||||
opts.limit = l
|
||||
} else if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
count, err := db.setIntercard(opts.keys, opts.limit)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
c.WriteInt(count)
|
||||
})
|
||||
}
|
||||
|
||||
// SISMEMBER
|
||||
func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, value := args[0], args[1]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if db.setIsMember(key, value) {
|
||||
c.WriteInt(1)
|
||||
return
|
||||
}
|
||||
c.WriteInt(0)
|
||||
})
|
||||
}
|
||||
|
||||
// SMEMBERS
|
||||
func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteSetLen(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
|
||||
c.WriteSetLen(len(members))
|
||||
for _, elem := range members {
|
||||
c.WriteBulk(elem)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SMISMEMBER
|
||||
func (m *Miniredis) cmdSmismember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, values := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteLen(len(values))
|
||||
for range values {
|
||||
c.WriteInt(0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(values))
|
||||
for _, value := range values {
|
||||
if db.setIsMember(key, value) {
|
||||
c.WriteInt(1)
|
||||
} else {
|
||||
c.WriteInt(0)
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// SMOVE
|
||||
func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 3 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
src, dst, member := args[0], args[1], args[2]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(src) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(src) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if db.exists(dst) && db.t(dst) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !db.setIsMember(src, member) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
db.setRem(src, member)
|
||||
db.setAdd(dst, member)
|
||||
c.WriteInt(1)
|
||||
})
|
||||
}
|
||||
|
||||
// SPOP
|
||||
func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := struct {
|
||||
key string
|
||||
withCount bool
|
||||
count int
|
||||
}{
|
||||
count: 1,
|
||||
}
|
||||
opts.key, args = args[0], args[1:]
|
||||
|
||||
if len(args) > 0 {
|
||||
v, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if v < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgOutOfRange)
|
||||
return
|
||||
}
|
||||
opts.count = v
|
||||
opts.withCount = true
|
||||
args = args[1:]
|
||||
}
|
||||
if len(args) > 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(opts.key) {
|
||||
if !opts.withCount {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(opts.key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var deleted []string
|
||||
members := db.setMembers(opts.key)
|
||||
for i := 0; i < opts.count; i++ {
|
||||
if len(members) == 0 {
|
||||
break
|
||||
}
|
||||
i := m.randIntn(len(members))
|
||||
member := members[i]
|
||||
members = delElem(members, i)
|
||||
db.setRem(opts.key, member)
|
||||
deleted = append(deleted, member)
|
||||
}
|
||||
// without `count` return a single value
|
||||
if !opts.withCount {
|
||||
if len(deleted) == 0 {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
c.WriteBulk(deleted[0])
|
||||
return
|
||||
}
|
||||
// with `count` return a list
|
||||
c.WriteLen(len(deleted))
|
||||
for _, v := range deleted {
|
||||
c.WriteBulk(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SRANDMEMBER
|
||||
func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if len(args) > 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key := args[0]
|
||||
count := 0
|
||||
withCount := false
|
||||
if len(args) == 2 {
|
||||
var err error
|
||||
count, err = strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
withCount = true
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
if withCount {
|
||||
c.WriteLen(0)
|
||||
return
|
||||
}
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
members := db.setMembers(key)
|
||||
if count < 0 {
|
||||
// Non-unique elements is allowed with negative count.
|
||||
c.WriteLen(-count)
|
||||
for count != 0 {
|
||||
member := members[m.randIntn(len(members))]
|
||||
c.WriteBulk(member)
|
||||
count++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Must be unique elements.
|
||||
m.shuffle(members)
|
||||
if count > len(members) {
|
||||
count = len(members)
|
||||
}
|
||||
if !withCount {
|
||||
c.WriteBulk(members[0])
|
||||
return
|
||||
}
|
||||
c.WriteLen(count)
|
||||
for i := range make([]struct{}, count) {
|
||||
c.WriteBulk(members[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SREM
|
||||
func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
key, fields := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
if !db.exists(key) {
|
||||
c.WriteInt(0)
|
||||
return
|
||||
}
|
||||
|
||||
if db.t(key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteInt(db.setRem(key, fields...))
|
||||
})
|
||||
}
|
||||
|
||||
// SUNION
|
||||
func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 1 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
keys := args
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setUnion(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.WriteLen(len(set))
|
||||
for k := range set {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SUNIONSTORE
|
||||
func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
dest, keys := args[0], args[1:]
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
set, err := db.setUnion(keys)
|
||||
if err != nil {
|
||||
c.WriteError(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
db.del(dest, true)
|
||||
db.setSet(dest, set)
|
||||
c.WriteInt(len(set))
|
||||
})
|
||||
}
|
||||
|
||||
// SSCAN
|
||||
func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
var opts struct {
|
||||
key string
|
||||
value int
|
||||
cursor int
|
||||
count int
|
||||
withMatch bool
|
||||
match string
|
||||
}
|
||||
|
||||
opts.key = args[0]
|
||||
if ok := optIntErr(c, args[1], &opts.cursor, msgInvalidCursor); !ok {
|
||||
return
|
||||
}
|
||||
args = args[2:]
|
||||
|
||||
// MATCH and COUNT options
|
||||
for len(args) > 0 {
|
||||
if strings.ToLower(args[0]) == "count" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
count, err := strconv.Atoi(args[1])
|
||||
if err != nil || count < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidInt)
|
||||
return
|
||||
}
|
||||
if count == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.count = count
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
if strings.ToLower(args[0]) == "match" {
|
||||
if len(args) < 2 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
opts.withMatch = true
|
||||
opts.match = args[1]
|
||||
args = args[2:]
|
||||
continue
|
||||
}
|
||||
setDirty(c)
|
||||
c.WriteError(msgSyntaxError)
|
||||
return
|
||||
}
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
db := m.db(ctx.selectedDB)
|
||||
// return _all_ (matched) keys every time
|
||||
if db.exists(opts.key) && db.t(opts.key) != "set" {
|
||||
c.WriteError(ErrWrongType.Error())
|
||||
return
|
||||
}
|
||||
members := db.setMembers(opts.key)
|
||||
if opts.withMatch {
|
||||
members, _ = matchKeys(members, opts.match)
|
||||
}
|
||||
low := opts.cursor
|
||||
high := low + opts.count
|
||||
// validate high is correct
|
||||
if high > len(members) || high == 0 {
|
||||
high = len(members)
|
||||
}
|
||||
if opts.cursor > high {
|
||||
// invalid cursor
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk("0") // no next cursor
|
||||
c.WriteLen(0) // no elements
|
||||
return
|
||||
}
|
||||
cursorValue := low + opts.count
|
||||
if cursorValue > len(members) {
|
||||
cursorValue = 0 // no next cursor
|
||||
}
|
||||
members = members[low:high]
|
||||
c.WriteLen(2)
|
||||
c.WriteBulk(fmt.Sprintf("%d", cursorValue))
|
||||
c.WriteLen(len(members))
|
||||
for _, k := range members {
|
||||
c.WriteBulk(k)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func delElem(ls []string, i int) []string {
|
||||
// this swap+truncate is faster but changes behaviour:
|
||||
// ls[i] = ls[len(ls)-1]
|
||||
// ls = ls[:len(ls)-1]
|
||||
// so we do the dumb thing:
|
||||
ls = append(ls[:i], ls[i+1:]...)
|
||||
return ls
|
||||
}
|
||||
+2025
File diff suppressed because it is too large
Load Diff
+1812
File diff suppressed because it is too large
Load Diff
+1364
File diff suppressed because it is too large
Load Diff
+179
@@ -0,0 +1,179 @@
|
||||
// Commands from https://redis.io/commands#transactions
|
||||
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// commandsTransaction handles MULTI &c.
|
||||
func commandsTransaction(m *Miniredis) {
|
||||
m.srv.Register("DISCARD", m.cmdDiscard)
|
||||
m.srv.Register("EXEC", m.cmdExec)
|
||||
m.srv.Register("MULTI", m.cmdMulti)
|
||||
m.srv.Register("UNWATCH", m.cmdUnwatch)
|
||||
m.srv.Register("WATCH", m.cmdWatch)
|
||||
}
|
||||
|
||||
// MULTI
|
||||
func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
if inTx(ctx) {
|
||||
c.WriteError("ERR MULTI calls can not be nested")
|
||||
return
|
||||
}
|
||||
|
||||
startTx(ctx)
|
||||
|
||||
c.WriteOK()
|
||||
}
|
||||
|
||||
// EXEC
|
||||
func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
if !inTx(ctx) {
|
||||
c.WriteError("ERR EXEC without MULTI")
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.dirtyTransaction {
|
||||
c.WriteError("EXECABORT Transaction discarded because of previous errors.")
|
||||
// a failed EXEC finishes the tx
|
||||
stopTx(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Check WATCHed keys.
|
||||
for t, version := range ctx.watch {
|
||||
if m.db(t.db).keyVersion[t.key] > version {
|
||||
// Abort! Abort!
|
||||
stopTx(ctx)
|
||||
c.WriteLen(-1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.WriteLen(len(ctx.transaction))
|
||||
for _, cb := range ctx.transaction {
|
||||
cb(c, ctx)
|
||||
}
|
||||
// wake up anyone who waits on anything.
|
||||
m.signal.Broadcast()
|
||||
|
||||
stopTx(ctx)
|
||||
}
|
||||
|
||||
// DISCARD
|
||||
func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getCtx(c)
|
||||
if !inTx(ctx) {
|
||||
c.WriteError("ERR DISCARD without MULTI")
|
||||
return
|
||||
}
|
||||
|
||||
stopTx(ctx)
|
||||
c.WriteOK()
|
||||
}
|
||||
|
||||
// WATCH
|
||||
func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) == 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := getCtx(c)
|
||||
if ctx.nested {
|
||||
c.WriteError(msgNotFromScripts(ctx.nestedSHA))
|
||||
return
|
||||
}
|
||||
if inTx(ctx) {
|
||||
c.WriteError("ERR WATCH in MULTI")
|
||||
return
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
db := m.db(ctx.selectedDB)
|
||||
|
||||
for _, key := range args {
|
||||
watch(db, ctx, key)
|
||||
}
|
||||
c.WriteOK()
|
||||
}
|
||||
|
||||
// UNWATCH
|
||||
func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) {
|
||||
if len(args) != 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(errWrongNumber(cmd))
|
||||
return
|
||||
}
|
||||
if !m.handleAuth(c) {
|
||||
return
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return
|
||||
}
|
||||
|
||||
// Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me.
|
||||
unwatch(getCtx(c))
|
||||
|
||||
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
|
||||
// Do nothing if it's called in a transaction.
|
||||
c.WriteOK()
|
||||
})
|
||||
}
|
||||
+790
@@ -0,0 +1,790 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidEntryID = errors.New("stream ID is invalid")
|
||||
)
|
||||
|
||||
// exists also updates the lru
|
||||
func (db *RedisDB) exists(k string) bool {
|
||||
_, ok := db.keys[k]
|
||||
if ok {
|
||||
db.lru[k] = db.master.effectiveNow()
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// t gives the type of a key, or ""
|
||||
func (db *RedisDB) t(k string) string {
|
||||
return db.keys[k]
|
||||
}
|
||||
|
||||
// incr increases the version and the lru timestamp
|
||||
func (db *RedisDB) incr(k string) {
|
||||
db.lru[k] = db.master.effectiveNow()
|
||||
db.keyVersion[k]++
|
||||
}
|
||||
|
||||
// allKeys returns all keys. Sorted.
|
||||
func (db *RedisDB) allKeys() []string {
|
||||
res := make([]string, 0, len(db.keys))
|
||||
for k := range db.keys {
|
||||
res = append(res, k)
|
||||
}
|
||||
sort.Strings(res) // To make things deterministic.
|
||||
return res
|
||||
}
|
||||
|
||||
// flush removes all keys and values.
|
||||
func (db *RedisDB) flush() {
|
||||
db.keys = map[string]string{}
|
||||
db.lru = map[string]time.Time{}
|
||||
db.stringKeys = map[string]string{}
|
||||
db.hashKeys = map[string]hashKey{}
|
||||
db.listKeys = map[string]listKey{}
|
||||
db.setKeys = map[string]setKey{}
|
||||
db.hllKeys = map[string]*hll{}
|
||||
db.sortedsetKeys = map[string]sortedSet{}
|
||||
db.ttl = map[string]time.Duration{}
|
||||
db.streamKeys = map[string]*streamKey{}
|
||||
}
|
||||
|
||||
// move something to another db. Will return ok. Or not.
|
||||
func (db *RedisDB) move(key string, to *RedisDB) bool {
|
||||
if _, ok := to.keys[key]; ok {
|
||||
return false
|
||||
}
|
||||
|
||||
t, ok := db.keys[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
to.keys[key] = db.keys[key]
|
||||
switch t {
|
||||
case keyTypeString:
|
||||
to.stringKeys[key] = db.stringKeys[key]
|
||||
case keyTypeHash:
|
||||
to.hashKeys[key] = db.hashKeys[key]
|
||||
case keyTypeList:
|
||||
to.listKeys[key] = db.listKeys[key]
|
||||
case keyTypeSet:
|
||||
to.setKeys[key] = db.setKeys[key]
|
||||
case keyTypeSortedSet:
|
||||
to.sortedsetKeys[key] = db.sortedsetKeys[key]
|
||||
case keyTypeStream:
|
||||
to.streamKeys[key] = db.streamKeys[key]
|
||||
case keyTypeHll:
|
||||
to.hllKeys[key] = db.hllKeys[key]
|
||||
default:
|
||||
panic("unhandled key type")
|
||||
}
|
||||
if v, ok := db.ttl[key]; ok {
|
||||
to.ttl[key] = v
|
||||
}
|
||||
to.incr(key)
|
||||
db.del(key, true)
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *RedisDB) rename(from, to string) {
|
||||
db.del(to, true)
|
||||
switch db.t(from) {
|
||||
case keyTypeString:
|
||||
db.stringKeys[to] = db.stringKeys[from]
|
||||
case keyTypeHash:
|
||||
db.hashKeys[to] = db.hashKeys[from]
|
||||
case keyTypeList:
|
||||
db.listKeys[to] = db.listKeys[from]
|
||||
case keyTypeSet:
|
||||
db.setKeys[to] = db.setKeys[from]
|
||||
case keyTypeSortedSet:
|
||||
db.sortedsetKeys[to] = db.sortedsetKeys[from]
|
||||
case keyTypeStream:
|
||||
db.streamKeys[to] = db.streamKeys[from]
|
||||
case keyTypeHll:
|
||||
db.hllKeys[to] = db.hllKeys[from]
|
||||
default:
|
||||
panic("missing case")
|
||||
}
|
||||
db.keys[to] = db.keys[from]
|
||||
if v, ok := db.ttl[from]; ok {
|
||||
db.ttl[to] = v
|
||||
}
|
||||
db.incr(to)
|
||||
|
||||
db.del(from, true)
|
||||
}
|
||||
|
||||
func (db *RedisDB) del(k string, delTTL bool) {
|
||||
if !db.exists(k) {
|
||||
return
|
||||
}
|
||||
t := db.t(k)
|
||||
delete(db.keys, k)
|
||||
delete(db.lru, k)
|
||||
db.keyVersion[k]++
|
||||
if delTTL {
|
||||
delete(db.ttl, k)
|
||||
}
|
||||
switch t {
|
||||
case keyTypeString:
|
||||
delete(db.stringKeys, k)
|
||||
case keyTypeHash:
|
||||
delete(db.hashKeys, k)
|
||||
case keyTypeList:
|
||||
delete(db.listKeys, k)
|
||||
case keyTypeSet:
|
||||
delete(db.setKeys, k)
|
||||
case keyTypeSortedSet:
|
||||
delete(db.sortedsetKeys, k)
|
||||
case keyTypeStream:
|
||||
delete(db.streamKeys, k)
|
||||
case keyTypeHll:
|
||||
delete(db.hllKeys, k)
|
||||
default:
|
||||
panic("Unknown key type: " + t)
|
||||
}
|
||||
}
|
||||
|
||||
// stringGet returns the string key or "" on error/nonexists.
|
||||
func (db *RedisDB) stringGet(k string) string {
|
||||
if t, ok := db.keys[k]; !ok || t != keyTypeString {
|
||||
return ""
|
||||
}
|
||||
return db.stringKeys[k]
|
||||
}
|
||||
|
||||
// stringSet force set()s a key. Does not touch expire.
|
||||
func (db *RedisDB) stringSet(k, v string) {
|
||||
db.del(k, false)
|
||||
db.keys[k] = keyTypeString
|
||||
db.stringKeys[k] = v
|
||||
db.incr(k)
|
||||
}
|
||||
|
||||
// change int key value
|
||||
func (db *RedisDB) stringIncr(k string, delta int) (int, error) {
|
||||
v := 0
|
||||
if sv, ok := db.stringKeys[k]; ok {
|
||||
var err error
|
||||
v, err = strconv.Atoi(sv)
|
||||
if err != nil {
|
||||
return 0, ErrIntValueError
|
||||
}
|
||||
}
|
||||
|
||||
if delta > 0 {
|
||||
if math.MaxInt-delta < v {
|
||||
return 0, ErrIntValueOverflowError
|
||||
}
|
||||
} else {
|
||||
if math.MinInt-delta > v {
|
||||
return 0, ErrIntValueOverflowError
|
||||
}
|
||||
}
|
||||
|
||||
v += delta
|
||||
db.stringSet(k, strconv.Itoa(v))
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// change float key value
|
||||
func (db *RedisDB) stringIncrfloat(k string, delta *big.Float) (*big.Float, error) {
|
||||
v := big.NewFloat(0.0)
|
||||
v.SetPrec(128)
|
||||
if sv, ok := db.stringKeys[k]; ok {
|
||||
var err error
|
||||
v, _, err = big.ParseFloat(sv, 10, 128, 0)
|
||||
if err != nil {
|
||||
return nil, ErrFloatValueError
|
||||
}
|
||||
}
|
||||
v.Add(v, delta)
|
||||
db.stringSet(k, formatBig(v))
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// listLpush is 'left push', aka unshift. Returns the new length.
|
||||
func (db *RedisDB) listLpush(k, v string) int {
|
||||
l, ok := db.listKeys[k]
|
||||
if !ok {
|
||||
db.keys[k] = keyTypeList
|
||||
}
|
||||
l = append([]string{v}, l...)
|
||||
db.listKeys[k] = l
|
||||
db.incr(k)
|
||||
return len(l)
|
||||
}
|
||||
|
||||
// 'left pop', aka shift.
|
||||
func (db *RedisDB) listLpop(k string) string {
|
||||
l := db.listKeys[k]
|
||||
el := l[0]
|
||||
l = l[1:]
|
||||
if len(l) == 0 {
|
||||
db.del(k, true)
|
||||
} else {
|
||||
db.listKeys[k] = l
|
||||
}
|
||||
db.incr(k)
|
||||
return el
|
||||
}
|
||||
|
||||
func (db *RedisDB) listPush(k string, v ...string) int {
|
||||
l, ok := db.listKeys[k]
|
||||
if !ok {
|
||||
db.keys[k] = keyTypeList
|
||||
}
|
||||
l = append(l, v...)
|
||||
db.listKeys[k] = l
|
||||
db.incr(k)
|
||||
return len(l)
|
||||
}
|
||||
|
||||
func (db *RedisDB) listPop(k string) string {
|
||||
l := db.listKeys[k]
|
||||
el := l[len(l)-1]
|
||||
l = l[:len(l)-1]
|
||||
if len(l) == 0 {
|
||||
db.del(k, true)
|
||||
} else {
|
||||
db.listKeys[k] = l
|
||||
db.incr(k)
|
||||
}
|
||||
return el
|
||||
}
|
||||
|
||||
// setset replaces a whole set.
|
||||
func (db *RedisDB) setSet(k string, set setKey) {
|
||||
db.keys[k] = keyTypeSet
|
||||
db.setKeys[k] = set
|
||||
db.incr(k)
|
||||
}
|
||||
|
||||
// setadd adds members to a set. Returns nr of new keys.
|
||||
func (db *RedisDB) setAdd(k string, elems ...string) int {
|
||||
s, ok := db.setKeys[k]
|
||||
if !ok {
|
||||
s = setKey{}
|
||||
db.keys[k] = keyTypeSet
|
||||
}
|
||||
added := 0
|
||||
for _, e := range elems {
|
||||
if _, ok := s[e]; !ok {
|
||||
added++
|
||||
}
|
||||
s[e] = struct{}{}
|
||||
}
|
||||
db.setKeys[k] = s
|
||||
db.incr(k)
|
||||
return added
|
||||
}
|
||||
|
||||
// setrem removes members from a set. Returns nr of deleted keys.
|
||||
func (db *RedisDB) setRem(k string, fields ...string) int {
|
||||
s, ok := db.setKeys[k]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
removed := 0
|
||||
for _, f := range fields {
|
||||
if _, ok := s[f]; ok {
|
||||
removed++
|
||||
delete(s, f)
|
||||
}
|
||||
}
|
||||
if len(s) == 0 {
|
||||
db.del(k, true)
|
||||
} else {
|
||||
db.setKeys[k] = s
|
||||
}
|
||||
db.incr(k)
|
||||
return removed
|
||||
}
|
||||
|
||||
// All members of a set.
|
||||
func (db *RedisDB) setMembers(k string) []string {
|
||||
set := db.setKeys[k]
|
||||
members := make([]string, 0, len(set))
|
||||
for k := range set {
|
||||
members = append(members, k)
|
||||
}
|
||||
sort.Strings(members)
|
||||
return members
|
||||
}
|
||||
|
||||
// Is a SET value present?
|
||||
func (db *RedisDB) setIsMember(k, v string) bool {
|
||||
set, ok := db.setKeys[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = set[v]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hashFields returns all (sorted) keys ('fields') for a hash key.
|
||||
func (db *RedisDB) hashFields(k string) []string {
|
||||
v := db.hashKeys[k]
|
||||
var r []string
|
||||
for k := range v {
|
||||
r = append(r, k)
|
||||
}
|
||||
sort.Strings(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// hashValues returns all (sorted) values a hash key.
|
||||
func (db *RedisDB) hashValues(k string) []string {
|
||||
h := db.hashKeys[k]
|
||||
var r []string
|
||||
for _, v := range h {
|
||||
r = append(r, v)
|
||||
}
|
||||
sort.Strings(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// hashGet a value
|
||||
func (db *RedisDB) hashGet(key, field string) string {
|
||||
return db.hashKeys[key][field]
|
||||
}
|
||||
|
||||
// hashSet returns the number of new keys
|
||||
func (db *RedisDB) hashSet(k string, fv ...string) int {
|
||||
if t, ok := db.keys[k]; ok && t != keyTypeHash {
|
||||
db.del(k, true)
|
||||
}
|
||||
db.keys[k] = keyTypeHash
|
||||
if _, ok := db.hashKeys[k]; !ok {
|
||||
db.hashKeys[k] = map[string]string{}
|
||||
}
|
||||
new := 0
|
||||
for idx := 0; idx < len(fv)-1; idx = idx + 2 {
|
||||
f, v := fv[idx], fv[idx+1]
|
||||
_, ok := db.hashKeys[k][f]
|
||||
db.hashKeys[k][f] = v
|
||||
db.incr(k)
|
||||
if !ok {
|
||||
new++
|
||||
}
|
||||
}
|
||||
return new
|
||||
}
|
||||
|
||||
// hashIncr changes int key value
|
||||
func (db *RedisDB) hashIncr(key, field string, delta int) (int, error) {
|
||||
v := 0
|
||||
if h, ok := db.hashKeys[key]; ok {
|
||||
if f, ok := h[field]; ok {
|
||||
var err error
|
||||
v, err = strconv.Atoi(f)
|
||||
if err != nil {
|
||||
return 0, ErrIntValueError
|
||||
}
|
||||
}
|
||||
}
|
||||
v += delta
|
||||
db.hashSet(key, field, strconv.Itoa(v))
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// hashIncrfloat changes float key value
|
||||
func (db *RedisDB) hashIncrfloat(key, field string, delta *big.Float) (*big.Float, error) {
|
||||
v := big.NewFloat(0.0)
|
||||
v.SetPrec(128)
|
||||
if h, ok := db.hashKeys[key]; ok {
|
||||
if f, ok := h[field]; ok {
|
||||
var err error
|
||||
v, _, err = big.ParseFloat(f, 10, 128, 0)
|
||||
if err != nil {
|
||||
return nil, ErrFloatValueError
|
||||
}
|
||||
}
|
||||
}
|
||||
v.Add(v, delta)
|
||||
db.hashSet(key, field, formatBig(v))
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// sortedSet set returns a sortedSet as map
|
||||
func (db *RedisDB) sortedSet(key string) map[string]float64 {
|
||||
ss := db.sortedsetKeys[key]
|
||||
return map[string]float64(ss)
|
||||
}
|
||||
|
||||
// ssetSet sets a complete sorted set.
|
||||
func (db *RedisDB) ssetSet(key string, sset sortedSet) {
|
||||
db.keys[key] = keyTypeSortedSet
|
||||
db.incr(key)
|
||||
db.sortedsetKeys[key] = sset
|
||||
}
|
||||
|
||||
// ssetAdd adds member to a sorted set. Returns whether this was a new member.
|
||||
func (db *RedisDB) ssetAdd(key string, score float64, member string) bool {
|
||||
ss, ok := db.sortedsetKeys[key]
|
||||
if !ok {
|
||||
ss = newSortedSet()
|
||||
db.keys[key] = keyTypeSortedSet
|
||||
}
|
||||
_, ok = ss[member]
|
||||
ss[member] = score
|
||||
db.sortedsetKeys[key] = ss
|
||||
db.incr(key)
|
||||
return !ok
|
||||
}
|
||||
|
||||
// All members from a sorted set, ordered by score.
|
||||
func (db *RedisDB) ssetMembers(key string) []string {
|
||||
ss, ok := db.sortedsetKeys[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
elems := ss.byScore(asc)
|
||||
members := make([]string, 0, len(elems))
|
||||
for _, e := range elems {
|
||||
members = append(members, e.member)
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
// All members+scores from a sorted set, ordered by score.
|
||||
func (db *RedisDB) ssetElements(key string) ssElems {
|
||||
ss, ok := db.sortedsetKeys[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ss.byScore(asc)
|
||||
}
|
||||
|
||||
func (db *RedisDB) ssetRandomMember(key string) string {
|
||||
elems := db.ssetElements(key)
|
||||
if len(elems) == 0 {
|
||||
return ""
|
||||
}
|
||||
return elems[db.master.randIntn(len(elems))].member
|
||||
}
|
||||
|
||||
// ssetCard is the sorted set cardinality.
|
||||
func (db *RedisDB) ssetCard(key string) int {
|
||||
ss := db.sortedsetKeys[key]
|
||||
return ss.card()
|
||||
}
|
||||
|
||||
// ssetRank is the sorted set rank.
|
||||
func (db *RedisDB) ssetRank(key, member string, d direction) (int, bool) {
|
||||
ss := db.sortedsetKeys[key]
|
||||
return ss.rankByScore(member, d)
|
||||
}
|
||||
|
||||
// ssetScore is sorted set score.
|
||||
func (db *RedisDB) ssetScore(key, member string) float64 {
|
||||
ss := db.sortedsetKeys[key]
|
||||
return ss[member]
|
||||
}
|
||||
|
||||
// ssetMScore returns multiple scores of a list of members in a sorted set.
|
||||
func (db *RedisDB) ssetMScore(key string, members []string) []float64 {
|
||||
scores := make([]float64, 0, len(members))
|
||||
ss := db.sortedsetKeys[key]
|
||||
for _, member := range members {
|
||||
scores = append(scores, ss[member])
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
// ssetRem is sorted set key delete.
|
||||
func (db *RedisDB) ssetRem(key, member string) bool {
|
||||
ss := db.sortedsetKeys[key]
|
||||
_, ok := ss[member]
|
||||
delete(ss, member)
|
||||
if len(ss) == 0 {
|
||||
// Delete key on removal of last member
|
||||
db.del(key, true)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// ssetExists tells if a member exists in a sorted set.
|
||||
func (db *RedisDB) ssetExists(key, member string) bool {
|
||||
ss := db.sortedsetKeys[key]
|
||||
_, ok := ss[member]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ssetIncrby changes float sorted set score.
|
||||
func (db *RedisDB) ssetIncrby(k, m string, delta float64) float64 {
|
||||
ss, ok := db.sortedsetKeys[k]
|
||||
if !ok {
|
||||
ss = newSortedSet()
|
||||
db.keys[k] = keyTypeSortedSet
|
||||
db.sortedsetKeys[k] = ss
|
||||
}
|
||||
|
||||
v, _ := ss.get(m)
|
||||
v += delta
|
||||
ss.set(v, m)
|
||||
db.incr(k)
|
||||
return v
|
||||
}
|
||||
|
||||
// setDiff implements the logic behind SDIFF*
|
||||
func (db *RedisDB) setDiff(keys []string) (setKey, error) {
|
||||
key := keys[0]
|
||||
keys = keys[1:]
|
||||
if db.exists(key) && db.t(key) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
s := setKey{}
|
||||
for k := range db.setKeys[key] {
|
||||
s[k] = struct{}{}
|
||||
}
|
||||
for _, sk := range keys {
|
||||
if !db.exists(sk) {
|
||||
continue
|
||||
}
|
||||
if db.t(sk) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
for e := range db.setKeys[sk] {
|
||||
delete(s, e)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// setInter implements the logic behind SINTER*
|
||||
// len keys needs to be > 0
|
||||
func (db *RedisDB) setInter(keys []string) (setKey, error) {
|
||||
// all keys must either not exist, or be of type "set".
|
||||
for _, key := range keys {
|
||||
if db.exists(key) && db.t(key) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
}
|
||||
|
||||
key := keys[0]
|
||||
keys = keys[1:]
|
||||
if !db.exists(key) {
|
||||
return nil, nil
|
||||
}
|
||||
if db.t(key) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
s := setKey{}
|
||||
for k := range db.setKeys[key] {
|
||||
s[k] = struct{}{}
|
||||
}
|
||||
for _, sk := range keys {
|
||||
if !db.exists(sk) {
|
||||
return setKey{}, nil
|
||||
}
|
||||
if db.t(sk) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
other := db.setKeys[sk]
|
||||
for e := range s {
|
||||
if _, ok := other[e]; ok {
|
||||
continue
|
||||
}
|
||||
delete(s, e)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// setIntercard implements the logic behind SINTER*
|
||||
// len keys needs to be > 0
|
||||
func (db *RedisDB) setIntercard(keys []string, limit int) (int, error) {
|
||||
// all keys must either not exist, or be of type "set".
|
||||
allExist := true
|
||||
for _, key := range keys {
|
||||
exists := db.exists(key)
|
||||
allExist = allExist && exists
|
||||
if exists && db.t(key) != "set" {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
}
|
||||
|
||||
if !allExist {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
smallestKey := keys[0]
|
||||
smallestIdx := 0
|
||||
for i, key := range keys {
|
||||
if len(db.setKeys[key]) < len(db.setKeys[smallestKey]) {
|
||||
smallestKey = key
|
||||
smallestIdx = i
|
||||
}
|
||||
}
|
||||
keys[smallestIdx] = keys[len(keys)-1]
|
||||
keys = keys[:len(keys)-1]
|
||||
|
||||
count := 0
|
||||
for item := range db.setKeys[smallestKey] {
|
||||
inIntersection := true
|
||||
for _, key := range keys {
|
||||
if _, ok := db.setKeys[key][item]; !ok {
|
||||
inIntersection = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if inIntersection {
|
||||
count++
|
||||
if count == limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// setUnion implements the logic behind SUNION*
|
||||
func (db *RedisDB) setUnion(keys []string) (setKey, error) {
|
||||
key := keys[0]
|
||||
keys = keys[1:]
|
||||
if db.exists(key) && db.t(key) != "set" {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
s := setKey{}
|
||||
for k := range db.setKeys[key] {
|
||||
s[k] = struct{}{}
|
||||
}
|
||||
for _, sk := range keys {
|
||||
if !db.exists(sk) {
|
||||
continue
|
||||
}
|
||||
if db.t(sk) != "set" {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
for e := range db.setKeys[sk] {
|
||||
s[e] = struct{}{}
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (db *RedisDB) newStream(key string) (*streamKey, error) {
|
||||
if s, err := db.stream(key); err != nil {
|
||||
return nil, err
|
||||
} else if s != nil {
|
||||
return nil, fmt.Errorf("ErrAlreadyExists")
|
||||
}
|
||||
|
||||
db.keys[key] = keyTypeStream
|
||||
s := newStreamKey()
|
||||
db.streamKeys[key] = s
|
||||
db.incr(key)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// return existing stream, or nil.
|
||||
func (db *RedisDB) stream(key string) (*streamKey, error) {
|
||||
if db.exists(key) && db.t(key) != keyTypeStream {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
|
||||
return db.streamKeys[key], nil
|
||||
}
|
||||
|
||||
// return existing stream group, or nil.
|
||||
func (db *RedisDB) streamGroup(key, group string) (*streamGroup, error) {
|
||||
s, err := db.stream(key)
|
||||
if err != nil || s == nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.groups[group], nil
|
||||
}
|
||||
|
||||
// fastForward proceeds the current timestamp with duration, works as a time machine
|
||||
func (db *RedisDB) fastForward(duration time.Duration) {
|
||||
for _, key := range db.allKeys() {
|
||||
if value, ok := db.ttl[key]; ok {
|
||||
db.ttl[key] = value - duration
|
||||
db.checkTTL(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *RedisDB) checkTTL(key string) {
|
||||
if v, ok := db.ttl[key]; ok && v <= 0 {
|
||||
db.del(key, true)
|
||||
}
|
||||
}
|
||||
|
||||
// hllAdd adds members to a hll. Returns 1 if at least 1 if internal HyperLogLog was altered, otherwise 0
|
||||
func (db *RedisDB) hllAdd(k string, elems ...string) int {
|
||||
s, ok := db.hllKeys[k]
|
||||
if !ok {
|
||||
s = newHll()
|
||||
db.keys[k] = keyTypeHll
|
||||
}
|
||||
hllAltered := 0
|
||||
for _, e := range elems {
|
||||
if s.Add([]byte(e)) {
|
||||
hllAltered = 1
|
||||
}
|
||||
}
|
||||
db.hllKeys[k] = s
|
||||
db.incr(k)
|
||||
return hllAltered
|
||||
}
|
||||
|
||||
// hllCount estimates the amount of members added to hll by hllAdd. If called with several arguments, hllCount returns a sum of estimations
|
||||
func (db *RedisDB) hllCount(keys []string) (int, error) {
|
||||
countOverall := 0
|
||||
for _, key := range keys {
|
||||
if db.exists(key) && db.t(key) != keyTypeHll {
|
||||
return 0, ErrNotValidHllValue
|
||||
}
|
||||
if !db.exists(key) {
|
||||
continue
|
||||
}
|
||||
countOverall += db.hllKeys[key].Count()
|
||||
}
|
||||
|
||||
return countOverall, nil
|
||||
}
|
||||
|
||||
// hllMerge merges all the hlls provided as keys to the first key. Creates a new hll in the first key if it contains nothing
|
||||
func (db *RedisDB) hllMerge(keys []string) error {
|
||||
for _, key := range keys {
|
||||
if db.exists(key) && db.t(key) != keyTypeHll {
|
||||
return ErrNotValidHllValue
|
||||
}
|
||||
}
|
||||
|
||||
destKey := keys[0]
|
||||
restKeys := keys[1:]
|
||||
|
||||
var destHll *hll
|
||||
if db.exists(destKey) {
|
||||
destHll = db.hllKeys[destKey]
|
||||
} else {
|
||||
destHll = newHll()
|
||||
}
|
||||
|
||||
for _, key := range restKeys {
|
||||
if !db.exists(key) {
|
||||
continue
|
||||
}
|
||||
destHll.Merge(db.hllKeys[key])
|
||||
}
|
||||
|
||||
db.hllKeys[destKey] = destHll
|
||||
db.keys[destKey] = keyTypeHll
|
||||
db.incr(destKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
+824
@@ -0,0 +1,824 @@
|
||||
package miniredis
|
||||
|
||||
// Commands to modify and query our databases directly.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrKeyNotFound is returned when a key doesn't exist.
|
||||
ErrKeyNotFound = errors.New(msgKeyNotFound)
|
||||
|
||||
// ErrWrongType when a key is not the right type.
|
||||
ErrWrongType = errors.New(msgWrongType)
|
||||
|
||||
// ErrNotValidHllValue when a key is not a valid HyperLogLog string value.
|
||||
ErrNotValidHllValue = errors.New(msgNotValidHllValue)
|
||||
|
||||
// ErrIntValueError can returned by INCRBY
|
||||
ErrIntValueError = errors.New(msgInvalidInt)
|
||||
|
||||
// ErrIntValueOverflowError can be returned by INCR, DECR, INCRBY, DECRBY
|
||||
ErrIntValueOverflowError = errors.New(msgIntOverflow)
|
||||
|
||||
// ErrFloatValueError can returned by INCRBYFLOAT
|
||||
ErrFloatValueError = errors.New(msgInvalidFloat)
|
||||
)
|
||||
|
||||
// Select sets the DB id for all direct commands.
|
||||
func (m *Miniredis) Select(i int) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.selectedDB = i
|
||||
}
|
||||
|
||||
// Keys returns all keys from the selected database, sorted.
|
||||
func (m *Miniredis) Keys() []string {
|
||||
return m.DB(m.selectedDB).Keys()
|
||||
}
|
||||
|
||||
// Keys returns all keys, sorted.
|
||||
func (db *RedisDB) Keys() []string {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.allKeys()
|
||||
}
|
||||
|
||||
// FlushAll removes all keys from all databases.
|
||||
func (m *Miniredis) FlushAll() {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
defer m.signal.Broadcast()
|
||||
|
||||
m.flushAll()
|
||||
}
|
||||
|
||||
func (m *Miniredis) flushAll() {
|
||||
for _, db := range m.dbs {
|
||||
db.flush()
|
||||
}
|
||||
}
|
||||
|
||||
// FlushDB removes all keys from the selected database.
|
||||
func (m *Miniredis) FlushDB() {
|
||||
m.DB(m.selectedDB).FlushDB()
|
||||
}
|
||||
|
||||
// FlushDB removes all keys.
|
||||
func (db *RedisDB) FlushDB() {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
db.flush()
|
||||
}
|
||||
|
||||
// Get returns string keys added with SET.
|
||||
func (m *Miniredis) Get(k string) (string, error) {
|
||||
return m.DB(m.selectedDB).Get(k)
|
||||
}
|
||||
|
||||
// Get returns a string key.
|
||||
func (db *RedisDB) Get(k string) (string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return "", ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeString {
|
||||
return "", ErrWrongType
|
||||
}
|
||||
return db.stringGet(k), nil
|
||||
}
|
||||
|
||||
// Set sets a string key. Removes expire.
|
||||
func (m *Miniredis) Set(k, v string) error {
|
||||
return m.DB(m.selectedDB).Set(k, v)
|
||||
}
|
||||
|
||||
// Set sets a string key. Removes expire.
|
||||
// Unlike redis the key can't be an existing non-string key.
|
||||
func (db *RedisDB) Set(k, v string) error {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeString {
|
||||
return ErrWrongType
|
||||
}
|
||||
db.del(k, true) // Remove expire
|
||||
db.stringSet(k, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Incr changes a int string value by delta.
|
||||
func (m *Miniredis) Incr(k string, delta int) (int, error) {
|
||||
return m.DB(m.selectedDB).Incr(k, delta)
|
||||
}
|
||||
|
||||
// Incr changes a int string value by delta.
|
||||
func (db *RedisDB) Incr(k string, delta int) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeString {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
|
||||
return db.stringIncr(k, delta)
|
||||
}
|
||||
|
||||
// IncrByFloat increments the float value of a key by the given delta.
|
||||
// is an alias for Miniredis.Incrfloat
|
||||
func (m *Miniredis) IncrByFloat(k string, delta float64) (float64, error) {
|
||||
return m.Incrfloat(k, delta)
|
||||
}
|
||||
|
||||
// Incrfloat changes a float string value by delta.
|
||||
func (m *Miniredis) Incrfloat(k string, delta float64) (float64, error) {
|
||||
return m.DB(m.selectedDB).Incrfloat(k, delta)
|
||||
}
|
||||
|
||||
// Incrfloat changes a float string value by delta.
|
||||
func (db *RedisDB) Incrfloat(k string, delta float64) (float64, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeString {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
|
||||
v, err := db.stringIncrfloat(k, big.NewFloat(delta))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
vf, _ := v.Float64()
|
||||
return vf, nil
|
||||
}
|
||||
|
||||
// List returns the list k, or an error if it's not there or something else.
|
||||
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
|
||||
// range-ing.
|
||||
func (m *Miniredis) List(k string) ([]string, error) {
|
||||
return m.DB(m.selectedDB).List(k)
|
||||
}
|
||||
|
||||
// List returns the list k, or an error if it's not there or something else.
|
||||
// This is the same as the Redis command `LRANGE 0 -1`, but you can do your own
|
||||
// range-ing.
|
||||
func (db *RedisDB) List(k string) ([]string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeList {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.listKeys[k], nil
|
||||
}
|
||||
|
||||
// Lpush prepends one value to a list. Returns the new length.
|
||||
func (m *Miniredis) Lpush(k, v string) (int, error) {
|
||||
return m.DB(m.selectedDB).Lpush(k, v)
|
||||
}
|
||||
|
||||
// Lpush prepends one value to a list. Returns the new length.
|
||||
func (db *RedisDB) Lpush(k, v string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeList {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.listLpush(k, v), nil
|
||||
}
|
||||
|
||||
// Lpop removes and returns the last element in a list.
|
||||
func (m *Miniredis) Lpop(k string) (string, error) {
|
||||
return m.DB(m.selectedDB).Lpop(k)
|
||||
}
|
||||
|
||||
// Lpop removes and returns the last element in a list.
|
||||
func (db *RedisDB) Lpop(k string) (string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if !db.exists(k) {
|
||||
return "", ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeList {
|
||||
return "", ErrWrongType
|
||||
}
|
||||
return db.listLpop(k), nil
|
||||
}
|
||||
|
||||
// RPush appends one or multiple values to a list. Returns the new length.
|
||||
// An alias for Push
|
||||
func (m *Miniredis) RPush(k string, v ...string) (int, error) {
|
||||
return m.Push(k, v...)
|
||||
}
|
||||
|
||||
// Push add element at the end. Returns the new length.
|
||||
func (m *Miniredis) Push(k string, v ...string) (int, error) {
|
||||
return m.DB(m.selectedDB).Push(k, v...)
|
||||
}
|
||||
|
||||
// Push add element at the end. Is called RPUSH in redis. Returns the new length.
|
||||
func (db *RedisDB) Push(k string, v ...string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeList {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.listPush(k, v...), nil
|
||||
}
|
||||
|
||||
// RPop is an alias for Pop
|
||||
func (m *Miniredis) RPop(k string) (string, error) {
|
||||
return m.Pop(k)
|
||||
}
|
||||
|
||||
// Pop removes and returns the last element. Is called RPOP in Redis.
|
||||
func (m *Miniredis) Pop(k string) (string, error) {
|
||||
return m.DB(m.selectedDB).Pop(k)
|
||||
}
|
||||
|
||||
// Pop removes and returns the last element. Is called RPOP in Redis.
|
||||
func (db *RedisDB) Pop(k string) (string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if !db.exists(k) {
|
||||
return "", ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeList {
|
||||
return "", ErrWrongType
|
||||
}
|
||||
|
||||
return db.listPop(k), nil
|
||||
}
|
||||
|
||||
// SAdd adds keys to a set. Returns the number of new keys.
|
||||
// Alias for SetAdd
|
||||
func (m *Miniredis) SAdd(k string, elems ...string) (int, error) {
|
||||
return m.SetAdd(k, elems...)
|
||||
}
|
||||
|
||||
// SetAdd adds keys to a set. Returns the number of new keys.
|
||||
func (m *Miniredis) SetAdd(k string, elems ...string) (int, error) {
|
||||
return m.DB(m.selectedDB).SetAdd(k, elems...)
|
||||
}
|
||||
|
||||
// SetAdd adds keys to a set. Returns the number of new keys.
|
||||
func (db *RedisDB) SetAdd(k string, elems ...string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeSet {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.setAdd(k, elems...), nil
|
||||
}
|
||||
|
||||
// SMembers returns all keys in a set, sorted.
|
||||
// Alias for Members.
|
||||
func (m *Miniredis) SMembers(k string) ([]string, error) {
|
||||
return m.Members(k)
|
||||
}
|
||||
|
||||
// Members returns all keys in a set, sorted.
|
||||
func (m *Miniredis) Members(k string) ([]string, error) {
|
||||
return m.DB(m.selectedDB).Members(k)
|
||||
}
|
||||
|
||||
// Members gives all set keys. Sorted.
|
||||
func (db *RedisDB) Members(k string) ([]string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.setMembers(k), nil
|
||||
}
|
||||
|
||||
// SIsMember tells if value is in the set.
|
||||
// Alias for IsMember
|
||||
func (m *Miniredis) SIsMember(k, v string) (bool, error) {
|
||||
return m.IsMember(k, v)
|
||||
}
|
||||
|
||||
// IsMember tells if value is in the set.
|
||||
func (m *Miniredis) IsMember(k, v string) (bool, error) {
|
||||
return m.DB(m.selectedDB).IsMember(k, v)
|
||||
}
|
||||
|
||||
// IsMember tells if value is in the set.
|
||||
func (db *RedisDB) IsMember(k, v string) (bool, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return false, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSet {
|
||||
return false, ErrWrongType
|
||||
}
|
||||
return db.setIsMember(k, v), nil
|
||||
}
|
||||
|
||||
// HKeys returns all (sorted) keys ('fields') for a hash key.
|
||||
func (m *Miniredis) HKeys(k string) ([]string, error) {
|
||||
return m.DB(m.selectedDB).HKeys(k)
|
||||
}
|
||||
|
||||
// HKeys returns all (sorted) keys ('fields') for a hash key.
|
||||
func (db *RedisDB) HKeys(key string) ([]string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(key) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(key) != keyTypeHash {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.hashFields(key), nil
|
||||
}
|
||||
|
||||
// Del deletes a key and any expiration value. Returns whether there was a key.
|
||||
func (m *Miniredis) Del(k string) bool {
|
||||
return m.DB(m.selectedDB).Del(k)
|
||||
}
|
||||
|
||||
// Del deletes a key and any expiration value. Returns whether there was a key.
|
||||
func (db *RedisDB) Del(k string) bool {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if !db.exists(k) {
|
||||
return false
|
||||
}
|
||||
db.del(k, true)
|
||||
return true
|
||||
}
|
||||
|
||||
// Unlink deletes a key and any expiration value. Returns where there was a key.
|
||||
// It's exactly the same as Del() and is not async. It is here for the consistency.
|
||||
func (m *Miniredis) Unlink(k string) bool {
|
||||
return m.Del(k)
|
||||
}
|
||||
|
||||
// Unlink deletes a key and any expiration value. Returns where there was a key.
|
||||
// It's exactly the same as Del() and is not async. It is here for the consistency.
|
||||
func (db *RedisDB) Unlink(k string) bool {
|
||||
return db.Del(k)
|
||||
}
|
||||
|
||||
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
|
||||
// PEXPIREAT.
|
||||
// Note: this direct function returns 0 if there is no TTL set, unlike redis,
|
||||
// which returns -1.
|
||||
func (m *Miniredis) TTL(k string) time.Duration {
|
||||
return m.DB(m.selectedDB).TTL(k)
|
||||
}
|
||||
|
||||
// TTL is the left over time to live. As set via EXPIRE, PEXPIRE, EXPIREAT,
|
||||
// PEXPIREAT.
|
||||
// 0 if not set.
|
||||
func (db *RedisDB) TTL(k string) time.Duration {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.ttl[k]
|
||||
}
|
||||
|
||||
// SetTTL sets the TTL of a key.
|
||||
func (m *Miniredis) SetTTL(k string, ttl time.Duration) {
|
||||
m.DB(m.selectedDB).SetTTL(k, ttl)
|
||||
}
|
||||
|
||||
// SetTTL sets the time to live of a key.
|
||||
func (db *RedisDB) SetTTL(k string, ttl time.Duration) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
db.ttl[k] = ttl
|
||||
db.incr(k)
|
||||
}
|
||||
|
||||
// Type gives the type of a key, or ""
|
||||
func (m *Miniredis) Type(k string) string {
|
||||
return m.DB(m.selectedDB).Type(k)
|
||||
}
|
||||
|
||||
// Type gives the type of a key, or ""
|
||||
func (db *RedisDB) Type(k string) string {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.t(k)
|
||||
}
|
||||
|
||||
// Exists tells whether a key exists.
|
||||
func (m *Miniredis) Exists(k string) bool {
|
||||
return m.DB(m.selectedDB).Exists(k)
|
||||
}
|
||||
|
||||
// Exists tells whether a key exists.
|
||||
func (db *RedisDB) Exists(k string) bool {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.exists(k)
|
||||
}
|
||||
|
||||
// HGet returns hash keys added with HSET.
|
||||
// This will return an empty string if the key is not set. Redis would return
|
||||
// a nil.
|
||||
// Returns empty string when the key is of a different type.
|
||||
func (m *Miniredis) HGet(k, f string) string {
|
||||
return m.DB(m.selectedDB).HGet(k, f)
|
||||
}
|
||||
|
||||
// HGet returns hash keys added with HSET.
|
||||
// Returns empty string when the key is of a different type.
|
||||
func (db *RedisDB) HGet(k, f string) string {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
h, ok := db.hashKeys[k]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return h[f]
|
||||
}
|
||||
|
||||
// HSet sets hash keys.
|
||||
// If there is another key by the same name it will be gone.
|
||||
func (m *Miniredis) HSet(k string, fv ...string) {
|
||||
m.DB(m.selectedDB).HSet(k, fv...)
|
||||
}
|
||||
|
||||
// HSet sets hash keys.
|
||||
// If there is another key by the same name it will be gone.
|
||||
func (db *RedisDB) HSet(k string, fv ...string) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
db.hashSet(k, fv...)
|
||||
}
|
||||
|
||||
// HDel deletes a hash key.
|
||||
func (m *Miniredis) HDel(k, f string) {
|
||||
m.DB(m.selectedDB).HDel(k, f)
|
||||
}
|
||||
|
||||
// HDel deletes a hash key.
|
||||
func (db *RedisDB) HDel(k, f string) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
db.hdel(k, f)
|
||||
}
|
||||
|
||||
func (db *RedisDB) hdel(k, f string) {
|
||||
if _, ok := db.hashKeys[k]; !ok {
|
||||
return
|
||||
}
|
||||
delete(db.hashKeys[k], f)
|
||||
db.incr(k)
|
||||
}
|
||||
|
||||
// HIncrBy increases the integer value of a hash field by delta (int).
|
||||
func (m *Miniredis) HIncrBy(k, f string, delta int) (int, error) {
|
||||
return m.HIncr(k, f, delta)
|
||||
}
|
||||
|
||||
// HIncr increases a key/field by delta (int).
|
||||
func (m *Miniredis) HIncr(k, f string, delta int) (int, error) {
|
||||
return m.DB(m.selectedDB).HIncr(k, f, delta)
|
||||
}
|
||||
|
||||
// HIncr increases a key/field by delta (int).
|
||||
func (db *RedisDB) HIncr(k, f string, delta int) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
return db.hashIncr(k, f, delta)
|
||||
}
|
||||
|
||||
// HIncrByFloat increases a key/field by delta (float).
|
||||
func (m *Miniredis) HIncrByFloat(k, f string, delta float64) (float64, error) {
|
||||
return m.HIncrfloat(k, f, delta)
|
||||
}
|
||||
|
||||
// HIncrfloat increases a key/field by delta (float).
|
||||
func (m *Miniredis) HIncrfloat(k, f string, delta float64) (float64, error) {
|
||||
return m.DB(m.selectedDB).HIncrfloat(k, f, delta)
|
||||
}
|
||||
|
||||
// HIncrfloat increases a key/field by delta (float).
|
||||
func (db *RedisDB) HIncrfloat(k, f string, delta float64) (float64, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
v, err := db.hashIncrfloat(k, f, big.NewFloat(delta))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
vf, _ := v.Float64()
|
||||
return vf, nil
|
||||
}
|
||||
|
||||
// SRem removes fields from a set. Returns number of deleted fields.
|
||||
func (m *Miniredis) SRem(k string, fields ...string) (int, error) {
|
||||
return m.DB(m.selectedDB).SRem(k, fields...)
|
||||
}
|
||||
|
||||
// SRem removes fields from a set. Returns number of deleted fields.
|
||||
func (db *RedisDB) SRem(k string, fields ...string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if !db.exists(k) {
|
||||
return 0, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSet {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.setRem(k, fields...), nil
|
||||
}
|
||||
|
||||
// ZAdd adds a score,member to a sorted set.
|
||||
func (m *Miniredis) ZAdd(k string, score float64, member string) (bool, error) {
|
||||
return m.DB(m.selectedDB).ZAdd(k, score, member)
|
||||
}
|
||||
|
||||
// ZAdd adds a score,member to a sorted set.
|
||||
func (db *RedisDB) ZAdd(k string, score float64, member string) (bool, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeSortedSet {
|
||||
return false, ErrWrongType
|
||||
}
|
||||
return db.ssetAdd(k, score, member), nil
|
||||
}
|
||||
|
||||
// ZMembers returns all members of a sorted set by score
|
||||
func (m *Miniredis) ZMembers(k string) ([]string, error) {
|
||||
return m.DB(m.selectedDB).ZMembers(k)
|
||||
}
|
||||
|
||||
// ZMembers returns all members of a sorted set by score
|
||||
func (db *RedisDB) ZMembers(k string) ([]string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSortedSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.ssetMembers(k), nil
|
||||
}
|
||||
|
||||
// SortedSet returns a raw string->float64 map.
|
||||
func (m *Miniredis) SortedSet(k string) (map[string]float64, error) {
|
||||
return m.DB(m.selectedDB).SortedSet(k)
|
||||
}
|
||||
|
||||
// SortedSet returns a raw string->float64 map.
|
||||
func (db *RedisDB) SortedSet(k string) (map[string]float64, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSortedSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.sortedSet(k), nil
|
||||
}
|
||||
|
||||
// ZRem deletes a member. Returns whether the was a key.
|
||||
func (m *Miniredis) ZRem(k, member string) (bool, error) {
|
||||
return m.DB(m.selectedDB).ZRem(k, member)
|
||||
}
|
||||
|
||||
// ZRem deletes a member. Returns whether the was a key.
|
||||
func (db *RedisDB) ZRem(k, member string) (bool, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
if !db.exists(k) {
|
||||
return false, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSortedSet {
|
||||
return false, ErrWrongType
|
||||
}
|
||||
return db.ssetRem(k, member), nil
|
||||
}
|
||||
|
||||
// ZScore gives the score of a sorted set member.
|
||||
func (m *Miniredis) ZScore(k, member string) (float64, error) {
|
||||
return m.DB(m.selectedDB).ZScore(k, member)
|
||||
}
|
||||
|
||||
// ZScore gives the score of a sorted set member.
|
||||
func (db *RedisDB) ZScore(k, member string) (float64, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return 0, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSortedSet {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.ssetScore(k, member), nil
|
||||
}
|
||||
|
||||
// ZScore gives scores of a list of members in a sorted set.
|
||||
func (m *Miniredis) ZMScore(k string, members ...string) ([]float64, error) {
|
||||
return m.DB(m.selectedDB).ZMScore(k, members)
|
||||
}
|
||||
|
||||
func (db *RedisDB) ZMScore(k string, members []string) ([]float64, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if !db.exists(k) {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
if db.t(k) != keyTypeSortedSet {
|
||||
return nil, ErrWrongType
|
||||
}
|
||||
return db.ssetMScore(k, members), nil
|
||||
}
|
||||
|
||||
// XAdd adds an entry to a stream. `id` can be left empty or be '*'.
|
||||
// If a value is given normal XADD rules apply. Values should be an even
|
||||
// length.
|
||||
func (m *Miniredis) XAdd(k string, id string, values []string) (string, error) {
|
||||
return m.DB(m.selectedDB).XAdd(k, id, values)
|
||||
}
|
||||
|
||||
// XAdd adds an entry to a stream. `id` can be left empty or be '*'.
|
||||
// If a value is given normal XADD rules apply. Values should be an even
|
||||
// length.
|
||||
func (db *RedisDB) XAdd(k string, id string, values []string) (string, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
defer db.master.signal.Broadcast()
|
||||
|
||||
s, err := db.stream(k)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if s == nil {
|
||||
s, _ = db.newStream(k)
|
||||
}
|
||||
|
||||
return s.add(id, values, db.master.effectiveNow())
|
||||
}
|
||||
|
||||
// Stream returns a slice of stream entries. Oldest first.
|
||||
func (m *Miniredis) Stream(k string) ([]StreamEntry, error) {
|
||||
return m.DB(m.selectedDB).Stream(k)
|
||||
}
|
||||
|
||||
// Stream returns a slice of stream entries. Oldest first.
|
||||
func (db *RedisDB) Stream(key string) ([]StreamEntry, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
s, err := db.stream(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.entries, nil
|
||||
}
|
||||
|
||||
// Publish a message to subscribers. Returns the number of receivers.
|
||||
func (m *Miniredis) Publish(channel, message string) int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
return m.publish(channel, message)
|
||||
}
|
||||
|
||||
// PubSubChannels is "PUBSUB CHANNELS <pattern>". An empty pattern is fine
|
||||
// (meaning all channels).
|
||||
// Returned channels will be ordered alphabetically.
|
||||
func (m *Miniredis) PubSubChannels(pattern string) []string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
return activeChannels(m.allSubscribers(), pattern)
|
||||
}
|
||||
|
||||
// PubSubNumSub is "PUBSUB NUMSUB [channels]". It returns all channels with their
|
||||
// subscriber count.
|
||||
func (m *Miniredis) PubSubNumSub(channels ...string) map[string]int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
subs := m.allSubscribers()
|
||||
res := map[string]int{}
|
||||
for _, channel := range channels {
|
||||
res[channel] = countSubs(subs, channel)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// PubSubNumPat is "PUBSUB NUMPAT"
|
||||
func (m *Miniredis) PubSubNumPat() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
return countPsubs(m.allSubscribers())
|
||||
}
|
||||
|
||||
// PfAdd adds keys to a hll. Returns the flag which equals to 1 if the inner hll value has been changed.
|
||||
func (m *Miniredis) PfAdd(k string, elems ...string) (int, error) {
|
||||
return m.DB(m.selectedDB).HllAdd(k, elems...)
|
||||
}
|
||||
|
||||
// HllAdd adds keys to a hll. Returns the flag which equals to true if the inner hll value has been changed.
|
||||
func (db *RedisDB) HllAdd(k string, elems ...string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
if db.exists(k) && db.t(k) != keyTypeHll {
|
||||
return 0, ErrWrongType
|
||||
}
|
||||
return db.hllAdd(k, elems...), nil
|
||||
}
|
||||
|
||||
// PfCount returns an estimation of the amount of elements previously added to a hll.
|
||||
func (m *Miniredis) PfCount(keys ...string) (int, error) {
|
||||
return m.DB(m.selectedDB).HllCount(keys...)
|
||||
}
|
||||
|
||||
// HllCount returns an estimation of the amount of elements previously added to a hll.
|
||||
func (db *RedisDB) HllCount(keys ...string) (int, error) {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.hllCount(keys)
|
||||
}
|
||||
|
||||
// PfMerge merges all the input hlls into a hll under destKey key.
|
||||
func (m *Miniredis) PfMerge(destKey string, sourceKeys ...string) error {
|
||||
return m.DB(m.selectedDB).HllMerge(destKey, sourceKeys...)
|
||||
}
|
||||
|
||||
// HllMerge merges all the input hlls into a hll under destKey key.
|
||||
func (db *RedisDB) HllMerge(destKey string, sourceKeys ...string) error {
|
||||
db.master.Lock()
|
||||
defer db.master.Unlock()
|
||||
|
||||
return db.hllMerge(append([]string{destKey}, sourceKeys...))
|
||||
}
|
||||
|
||||
// Copy a value.
|
||||
// Needs the IDs of both the source and dest DBs (which can differ).
|
||||
// Returns ErrKeyNotFound if src does not exist.
|
||||
// Overwrites dest if it already exists (unlike the redis command, which needs a flag to allow that).
|
||||
func (m *Miniredis) Copy(srcDB int, src string, destDB int, dest string) error {
|
||||
return m.copy(m.DB(srcDB), src, m.DB(destDB), dest)
|
||||
}
|
||||
+26
@@ -0,0 +1,26 @@
|
||||
This code is derived from the C code in redis-7.2.0/deps/fpconv/*, which has
|
||||
this license:
|
||||
|
||||
Boost Software License - Version 1.0 - August 17th, 2003
|
||||
|
||||
Permission is hereby granted, free of charge, to any person or organization
|
||||
obtaining a copy of the software and accompanying documentation covered by
|
||||
this license (the "Software") to use, reproduce, display, distribute,
|
||||
execute, and transmit the Software, and to prepare derivative works of the
|
||||
Software, and to permit third-parties to whom the Software is furnished to
|
||||
do so, all subject to the following:
|
||||
|
||||
The copyright notices in the Software and this entire statement, including
|
||||
the above license grant, this restriction and the following disclaimer,
|
||||
must be included in all copies of the Software, in whole or in part, and
|
||||
all derivative works of the Software, unless such copies or derivative
|
||||
works are solely in the form of machine-executable object code generated by
|
||||
a source language processor.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
|
||||
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
|
||||
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
+6
@@ -0,0 +1,6 @@
|
||||
.PHONY: test fuzz
|
||||
test:
|
||||
go test
|
||||
|
||||
fuzz:
|
||||
go test -fuzz=Fuzz
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
This is a translation of the actual C code in Redis (7.2) which does the float
|
||||
-> string conversion.
|
||||
Strconv does a close enough job, but we can use the exact same logic, so why not.
|
||||
+286
@@ -0,0 +1,286 @@
|
||||
package fpconv
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
var (
|
||||
fracmask uint64 = 0x000FFFFFFFFFFFFF
|
||||
expmask uint64 = 0x7FF0000000000000
|
||||
hiddenbit uint64 = 0x0010000000000000
|
||||
signmask uint64 = 0x8000000000000000
|
||||
expbias int64 = 1023 + 52
|
||||
zeros = []rune("0000000000000000000000")
|
||||
|
||||
tens = []uint64{
|
||||
10000000000000000000,
|
||||
1000000000000000000,
|
||||
100000000000000000,
|
||||
10000000000000000,
|
||||
1000000000000000,
|
||||
100000000000000,
|
||||
10000000000000,
|
||||
1000000000000,
|
||||
100000000000,
|
||||
10000000000,
|
||||
1000000000,
|
||||
100000000,
|
||||
10000000,
|
||||
1000000,
|
||||
100000,
|
||||
10000,
|
||||
1000,
|
||||
100,
|
||||
10,
|
||||
1}
|
||||
)
|
||||
|
||||
func absv(n int) int {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func minv(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func Dtoa(d float64) string {
|
||||
var (
|
||||
dest [25]rune // Note C has 24, which is broken
|
||||
digits [18]rune
|
||||
|
||||
str_len int = 0
|
||||
neg = false
|
||||
)
|
||||
|
||||
if get_dbits(d)&signmask != 0 {
|
||||
dest[0] = '-'
|
||||
str_len++
|
||||
neg = true
|
||||
}
|
||||
|
||||
if spec := filter_special(d, dest[str_len:]); spec != 0 {
|
||||
return string(dest[:str_len+spec])
|
||||
}
|
||||
|
||||
var (
|
||||
k int = 0
|
||||
ndigits int = grisu2(d, &digits, &k)
|
||||
)
|
||||
|
||||
str_len += emit_digits(&digits, ndigits, dest[str_len:], k, neg)
|
||||
return string(dest[:str_len])
|
||||
}
|
||||
|
||||
func filter_special(fp float64, dest []rune) int {
|
||||
if fp == 0.0 {
|
||||
dest[0] = '0'
|
||||
return 1
|
||||
}
|
||||
|
||||
if math.IsNaN(fp) {
|
||||
dest[0] = 'n'
|
||||
dest[1] = 'a'
|
||||
dest[2] = 'n'
|
||||
return 3
|
||||
}
|
||||
if math.IsInf(fp, 0) {
|
||||
dest[0] = 'i'
|
||||
dest[1] = 'n'
|
||||
dest[2] = 'f'
|
||||
return 3
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func grisu2(d float64, digits *[18]rune, K *int) int {
|
||||
w := build_fp(d)
|
||||
|
||||
lower, upper := get_normalized_boundaries(w)
|
||||
|
||||
w = normalize(w)
|
||||
|
||||
var k int64
|
||||
cp := find_cachedpow10(upper.exp, &k)
|
||||
|
||||
w = multiply(w, cp)
|
||||
upper = multiply(upper, cp)
|
||||
lower = multiply(lower, cp)
|
||||
|
||||
lower.frac++
|
||||
upper.frac--
|
||||
|
||||
*K = int(-k)
|
||||
|
||||
return generate_digits(w, upper, lower, digits[:], K)
|
||||
}
|
||||
|
||||
func emit_digits(digits *[18]rune, ndigits int, dest []rune, K int, neg bool) int {
|
||||
exp := int(absv(K + ndigits - 1))
|
||||
|
||||
/* write plain integer */
|
||||
if K >= 0 && (exp < (ndigits + 7)) {
|
||||
copy(dest, digits[:ndigits])
|
||||
copy(dest[ndigits:], zeros[:K])
|
||||
|
||||
return ndigits + K
|
||||
}
|
||||
|
||||
/* write decimal w/o scientific notation */
|
||||
if K < 0 && (K > -7 || exp < 4) {
|
||||
offset := int(ndigits - absv(K))
|
||||
/* fp < 1.0 -> write leading zero */
|
||||
if offset <= 0 {
|
||||
offset = -offset
|
||||
dest[0] = '0'
|
||||
dest[1] = '.'
|
||||
copy(dest[2:], zeros[:offset])
|
||||
copy(dest[offset+2:], digits[:ndigits])
|
||||
|
||||
return ndigits + 2 + offset
|
||||
|
||||
/* fp > 1.0 */
|
||||
} else {
|
||||
copy(dest, digits[:offset])
|
||||
dest[offset] = '.'
|
||||
copy(dest[offset+1:], digits[offset:offset+ndigits-offset])
|
||||
|
||||
return ndigits + 1
|
||||
}
|
||||
}
|
||||
/* write decimal w/ scientific notation */
|
||||
l := 18 // was: 18-neg
|
||||
if neg {
|
||||
l--
|
||||
}
|
||||
ndigits = minv(ndigits, l)
|
||||
|
||||
var idx int = 0
|
||||
dest[idx] = digits[0]
|
||||
idx++
|
||||
|
||||
if ndigits > 1 {
|
||||
dest[idx] = '.'
|
||||
idx++
|
||||
copy(dest[idx:], digits[+1:ndigits-1+1])
|
||||
idx += ndigits - 1
|
||||
}
|
||||
|
||||
dest[idx] = 'e'
|
||||
idx++
|
||||
|
||||
sign := '+'
|
||||
if K+ndigits-1 < 0 {
|
||||
sign = '-'
|
||||
}
|
||||
dest[idx] = sign
|
||||
idx++
|
||||
|
||||
var cent rune = 0
|
||||
|
||||
if exp > 99 {
|
||||
cent = rune(exp / 100)
|
||||
dest[idx] = cent + '0'
|
||||
idx++
|
||||
exp -= int(cent) * 100
|
||||
}
|
||||
if exp > 9 {
|
||||
dec := rune(exp / 10)
|
||||
dest[idx] = dec + '0'
|
||||
idx++
|
||||
exp -= int(dec) * 10
|
||||
} else if cent != 0 {
|
||||
dest[idx] = '0'
|
||||
idx++
|
||||
}
|
||||
|
||||
dest[idx] = rune(exp%10) + '0'
|
||||
idx++
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
func generate_digits(fp, upper, lower Fp, digits []rune, K *int) int {
|
||||
var (
|
||||
wfrac = uint64(upper.frac - fp.frac)
|
||||
delta = uint64(upper.frac - lower.frac)
|
||||
)
|
||||
|
||||
one := Fp{
|
||||
frac: 1 << -upper.exp,
|
||||
exp: upper.exp,
|
||||
}
|
||||
|
||||
part1 := uint64(upper.frac >> -one.exp)
|
||||
part2 := uint64(upper.frac & (one.frac - 1))
|
||||
|
||||
var (
|
||||
idx = 0
|
||||
kappa = 10
|
||||
index = 10
|
||||
)
|
||||
/* 1000000000 */
|
||||
for ; kappa > 0; index++ {
|
||||
div := tens[index]
|
||||
digit := part1 / div
|
||||
|
||||
if digit != 0 || idx != 0 {
|
||||
digits[idx] = rune(digit) + '0'
|
||||
idx++
|
||||
}
|
||||
|
||||
part1 -= digit * div
|
||||
kappa--
|
||||
|
||||
tmp := (part1 << -one.exp) + part2
|
||||
if tmp <= delta {
|
||||
*K += kappa
|
||||
round_digit(digits, idx, delta, tmp, div<<-one.exp, wfrac)
|
||||
|
||||
return idx
|
||||
}
|
||||
}
|
||||
|
||||
/* 10 */
|
||||
index = 18
|
||||
for {
|
||||
var unit uint64 = tens[index]
|
||||
part2 *= 10
|
||||
delta *= 10
|
||||
kappa--
|
||||
|
||||
digit := part2 >> -one.exp
|
||||
if digit != 0 || idx != 0 {
|
||||
digits[idx] = rune(digit) + '0'
|
||||
idx++
|
||||
}
|
||||
|
||||
part2 &= uint64(one.frac) - 1
|
||||
if part2 < delta {
|
||||
*K += kappa
|
||||
round_digit(digits, idx, delta, part2, uint64(one.frac), wfrac*unit)
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
index--
|
||||
}
|
||||
}
|
||||
|
||||
func round_digit(digits []rune,
|
||||
ndigits int,
|
||||
delta uint64,
|
||||
rem uint64,
|
||||
kappa uint64,
|
||||
frac uint64) {
|
||||
for rem < frac && delta-rem >= kappa &&
|
||||
(rem+kappa < frac || frac-rem > rem+kappa-frac) {
|
||||
digits[ndigits-1]--
|
||||
rem += kappa
|
||||
}
|
||||
}
|
||||
+96
@@ -0,0 +1,96 @@
|
||||
package fpconv
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
type (
|
||||
Fp struct {
|
||||
frac uint64
|
||||
exp int64
|
||||
}
|
||||
)
|
||||
|
||||
func build_fp(d float64) Fp {
|
||||
bits := get_dbits(d)
|
||||
|
||||
fp := Fp{
|
||||
frac: bits & fracmask,
|
||||
exp: int64((bits & expmask) >> 52),
|
||||
}
|
||||
|
||||
if fp.exp != 0 {
|
||||
fp.frac += hiddenbit
|
||||
fp.exp -= expbias
|
||||
} else {
|
||||
fp.exp = -expbias + 1
|
||||
}
|
||||
|
||||
return fp
|
||||
}
|
||||
|
||||
func normalize(fp Fp) Fp {
|
||||
for (fp.frac & hiddenbit) == 0 {
|
||||
fp.frac <<= 1
|
||||
fp.exp--
|
||||
}
|
||||
|
||||
var shift int64 = 64 - 52 - 1
|
||||
fp.frac <<= shift
|
||||
fp.exp -= shift
|
||||
return fp
|
||||
}
|
||||
|
||||
func multiply(a, b Fp) Fp {
|
||||
lomask := uint64(0x00000000FFFFFFFF)
|
||||
|
||||
var (
|
||||
ah_bl = uint64((a.frac >> 32) * (b.frac & lomask))
|
||||
al_bh = uint64((a.frac & lomask) * (b.frac >> 32))
|
||||
al_bl = uint64((a.frac & lomask) * (b.frac & lomask))
|
||||
ah_bh = uint64((a.frac >> 32) * (b.frac >> 32))
|
||||
)
|
||||
|
||||
tmp := uint64((ah_bl & lomask) + (al_bh & lomask) + (al_bl >> 32))
|
||||
/* round up */
|
||||
tmp += uint64(1) << 31
|
||||
|
||||
return Fp{
|
||||
ah_bh + (ah_bl >> 32) + (al_bh >> 32) + (tmp >> 32),
|
||||
a.exp + b.exp + 64,
|
||||
}
|
||||
}
|
||||
|
||||
func get_dbits(d float64) uint64 {
|
||||
return math.Float64bits(d)
|
||||
}
|
||||
|
||||
func get_normalized_boundaries(fp Fp) (Fp, Fp) {
|
||||
upper := Fp{
|
||||
frac: (fp.frac << 1) + 1,
|
||||
exp: fp.exp - 1,
|
||||
}
|
||||
for (upper.frac & (hiddenbit << 1)) == 0 {
|
||||
upper.frac <<= 1
|
||||
upper.exp--
|
||||
}
|
||||
|
||||
var u_shift int64 = 64 - 52 - 2
|
||||
|
||||
upper.frac <<= u_shift
|
||||
upper.exp = upper.exp - u_shift
|
||||
|
||||
l_shift := int64(1)
|
||||
if fp.frac == hiddenbit {
|
||||
l_shift = 2
|
||||
}
|
||||
|
||||
lower := Fp{
|
||||
frac: (fp.frac << l_shift) - 1,
|
||||
exp: fp.exp - l_shift,
|
||||
}
|
||||
|
||||
lower.frac <<= lower.exp - upper.exp
|
||||
lower.exp = upper.exp
|
||||
return lower, upper
|
||||
}
|
||||
+82
@@ -0,0 +1,82 @@
|
||||
package fpconv
|
||||
|
||||
var (
|
||||
npowers int64 = 87
|
||||
steppowers int64 = 8
|
||||
firstpower int64 = -348 /* 10 ^ -348 */
|
||||
|
||||
expmax = -32
|
||||
expmin = -60
|
||||
|
||||
powers_ten = []Fp{
|
||||
{18054884314459144840, -1220}, {13451937075301367670, -1193},
|
||||
{10022474136428063862, -1166}, {14934650266808366570, -1140},
|
||||
{11127181549972568877, -1113}, {16580792590934885855, -1087},
|
||||
{12353653155963782858, -1060}, {18408377700990114895, -1034},
|
||||
{13715310171984221708, -1007}, {10218702384817765436, -980},
|
||||
{15227053142812498563, -954}, {11345038669416679861, -927},
|
||||
{16905424996341287883, -901}, {12595523146049147757, -874},
|
||||
{9384396036005875287, -847}, {13983839803942852151, -821},
|
||||
{10418772551374772303, -794}, {15525180923007089351, -768},
|
||||
{11567161174868858868, -741}, {17236413322193710309, -715},
|
||||
{12842128665889583758, -688}, {9568131466127621947, -661},
|
||||
{14257626930069360058, -635}, {10622759856335341974, -608},
|
||||
{15829145694278690180, -582}, {11793632577567316726, -555},
|
||||
{17573882009934360870, -529}, {13093562431584567480, -502},
|
||||
{9755464219737475723, -475}, {14536774485912137811, -449},
|
||||
{10830740992659433045, -422}, {16139061738043178685, -396},
|
||||
{12024538023802026127, -369}, {17917957937422433684, -343},
|
||||
{13349918974505688015, -316}, {9946464728195732843, -289},
|
||||
{14821387422376473014, -263}, {11042794154864902060, -236},
|
||||
{16455045573212060422, -210}, {12259964326927110867, -183},
|
||||
{18268770466636286478, -157}, {13611294676837538539, -130},
|
||||
{10141204801825835212, -103}, {15111572745182864684, -77},
|
||||
{11258999068426240000, -50}, {16777216000000000000, -24},
|
||||
{12500000000000000000, 3}, {9313225746154785156, 30},
|
||||
{13877787807814456755, 56}, {10339757656912845936, 83},
|
||||
{15407439555097886824, 109}, {11479437019748901445, 136},
|
||||
{17105694144590052135, 162}, {12744735289059618216, 189},
|
||||
{9495567745759798747, 216}, {14149498560666738074, 242},
|
||||
{10542197943230523224, 269}, {15709099088952724970, 295},
|
||||
{11704190886730495818, 322}, {17440603504673385349, 348},
|
||||
{12994262207056124023, 375}, {9681479787123295682, 402},
|
||||
{14426529090290212157, 428}, {10748601772107342003, 455},
|
||||
{16016664761464807395, 481}, {11933345169920330789, 508},
|
||||
{17782069995880619868, 534}, {13248674568444952270, 561},
|
||||
{9871031767461413346, 588}, {14708983551653345445, 614},
|
||||
{10959046745042015199, 641}, {16330252207878254650, 667},
|
||||
{12166986024289022870, 694}, {18130221999122236476, 720},
|
||||
{13508068024458167312, 747}, {10064294952495520794, 774},
|
||||
{14996968138956309548, 800}, {11173611982879273257, 827},
|
||||
{16649979327439178909, 853}, {12405201291620119593, 880},
|
||||
{9242595204427927429, 907}, {13772540099066387757, 933},
|
||||
{10261342003245940623, 960}, {15290591125556738113, 986},
|
||||
{11392378155556871081, 1013}, {16975966327722178521, 1039},
|
||||
{12648080533535911531, 1066},
|
||||
}
|
||||
)
|
||||
|
||||
func find_cachedpow10(exp int64, k *int64) Fp {
|
||||
one_log_ten := 0.30102999566398114
|
||||
|
||||
approx := int64(float64(-(exp + npowers)) * one_log_ten)
|
||||
idx := int((approx - firstpower) / steppowers)
|
||||
|
||||
for {
|
||||
current := int(exp + powers_ten[idx].exp + 64)
|
||||
|
||||
if current < expmin {
|
||||
idx++
|
||||
continue
|
||||
}
|
||||
|
||||
if current > expmax {
|
||||
idx--
|
||||
continue
|
||||
}
|
||||
|
||||
*k = (firstpower + int64(idx)*steppowers)
|
||||
|
||||
return powers_ten[idx]
|
||||
}
|
||||
}
|
||||
+46
@@ -0,0 +1,46 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/geohash"
|
||||
)
|
||||
|
||||
func toGeohash(long, lat float64) uint64 {
|
||||
return geohash.EncodeIntWithPrecision(lat, long, 52)
|
||||
}
|
||||
|
||||
func fromGeohash(score uint64) (float64, float64) {
|
||||
lat, long := geohash.DecodeIntWithPrecision(score, 52)
|
||||
return long, lat
|
||||
}
|
||||
|
||||
// haversin(θ) function
|
||||
func hsin(theta float64) float64 {
|
||||
return math.Pow(math.Sin(theta/2), 2)
|
||||
}
|
||||
|
||||
// distance function returns the distance (in meters) between two points of
|
||||
// a given longitude and latitude relatively accurately (using a spherical
|
||||
// approximation of the Earth) through the Haversin Distance Formula for
|
||||
// great arc distance on a sphere with accuracy for small distances
|
||||
// point coordinates are supplied in degrees and converted into rad. in the func
|
||||
// distance returned is meters
|
||||
// http://en.wikipedia.org/wiki/Haversine_formula
|
||||
// Source: https://gist.github.com/cdipaolo/d3f8db3848278b49db68
|
||||
func distance(lat1, lon1, lat2, lon2 float64) float64 {
|
||||
// convert to radians
|
||||
// must cast radius as float to multiply later
|
||||
var la1, lo1, la2, lo2 float64
|
||||
la1 = lat1 * math.Pi / 180
|
||||
lo1 = lon1 * math.Pi / 180
|
||||
la2 = lat2 * math.Pi / 180
|
||||
lo2 = lon2 * math.Pi / 180
|
||||
|
||||
earth := 6372797.560856 // Earth radius in METERS, according to src/geohash_helper.c
|
||||
|
||||
// calculate
|
||||
h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1)
|
||||
|
||||
return 2 * earth * math.Asin(math.Sqrt(h))
|
||||
}
|
||||
+22
@@ -0,0 +1,22 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Michael McLoughlin
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
+2
@@ -0,0 +1,2 @@
|
||||
This is a (selected) copy of github.com/mmcloughlin/geohash with the latitude
|
||||
range changed from 90 to ~85, to align with the algorithm use by Redis.
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
package geohash
|
||||
|
||||
// encoding encapsulates an encoding defined by a given base32 alphabet.
|
||||
type encoding struct {
|
||||
encode string
|
||||
decode [256]byte
|
||||
}
|
||||
|
||||
// newEncoding constructs a new encoding defined by the given alphabet,
|
||||
// which must be a 32-byte string.
|
||||
func newEncoding(encoder string) *encoding {
|
||||
e := new(encoding)
|
||||
e.encode = encoder
|
||||
for i := 0; i < len(e.decode); i++ {
|
||||
e.decode[i] = 0xff
|
||||
}
|
||||
for i := 0; i < len(encoder); i++ {
|
||||
e.decode[encoder[i]] = byte(i)
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Decode string into bits of a 64-bit word. The string s may be at most 12
|
||||
// characters.
|
||||
func (e *encoding) Decode(s string) uint64 {
|
||||
x := uint64(0)
|
||||
for i := 0; i < len(s); i++ {
|
||||
x = (x << 5) | uint64(e.decode[s[i]])
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// Encode bits of 64-bit word into a string.
|
||||
func (e *encoding) Encode(x uint64) string {
|
||||
b := [12]byte{}
|
||||
for i := 0; i < 12; i++ {
|
||||
b[11-i] = e.encode[x&0x1f]
|
||||
x >>= 5
|
||||
}
|
||||
return string(b[:])
|
||||
}
|
||||
|
||||
// Base32Encoding with the Geohash alphabet.
|
||||
var base32encoding = newEncoding("0123456789bcdefghjkmnpqrstuvwxyz")
|
||||
+269
@@ -0,0 +1,269 @@
|
||||
// Package geohash provides encoding and decoding of string and integer
|
||||
// geohashes.
|
||||
package geohash
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
ENC_LAT = 85.05112878
|
||||
ENC_LONG = 180.0
|
||||
)
|
||||
|
||||
// Direction represents directions in the latitute/longitude space.
|
||||
type Direction int
|
||||
|
||||
// Cardinal and intercardinal directions
|
||||
const (
|
||||
North Direction = iota
|
||||
NorthEast
|
||||
East
|
||||
SouthEast
|
||||
South
|
||||
SouthWest
|
||||
West
|
||||
NorthWest
|
||||
)
|
||||
|
||||
// Encode the point (lat, lng) as a string geohash with the standard 12
|
||||
// characters of precision.
|
||||
func Encode(lat, lng float64) string {
|
||||
return EncodeWithPrecision(lat, lng, 12)
|
||||
}
|
||||
|
||||
// EncodeWithPrecision encodes the point (lat, lng) as a string geohash with
|
||||
// the specified number of characters of precision (max 12).
|
||||
func EncodeWithPrecision(lat, lng float64, chars uint) string {
|
||||
bits := 5 * chars
|
||||
inthash := EncodeIntWithPrecision(lat, lng, bits)
|
||||
enc := base32encoding.Encode(inthash)
|
||||
return enc[12-chars:]
|
||||
}
|
||||
|
||||
// encodeInt provides a Go implementation of integer geohash. This is the
|
||||
// default implementation of EncodeInt, but optimized versions are provided
|
||||
// for certain architectures.
|
||||
func EncodeInt(lat, lng float64) uint64 {
|
||||
latInt := encodeRange(lat, ENC_LAT)
|
||||
lngInt := encodeRange(lng, ENC_LONG)
|
||||
return interleave(latInt, lngInt)
|
||||
}
|
||||
|
||||
// EncodeIntWithPrecision encodes the point (lat, lng) to an integer with the
|
||||
// specified number of bits.
|
||||
func EncodeIntWithPrecision(lat, lng float64, bits uint) uint64 {
|
||||
hash := EncodeInt(lat, lng)
|
||||
return hash >> (64 - bits)
|
||||
}
|
||||
|
||||
// Box represents a rectangle in latitude/longitude space.
|
||||
type Box struct {
|
||||
MinLat float64
|
||||
MaxLat float64
|
||||
MinLng float64
|
||||
MaxLng float64
|
||||
}
|
||||
|
||||
// Center returns the center of the box.
|
||||
func (b Box) Center() (lat, lng float64) {
|
||||
lat = (b.MinLat + b.MaxLat) / 2.0
|
||||
lng = (b.MinLng + b.MaxLng) / 2.0
|
||||
return
|
||||
}
|
||||
|
||||
// Contains decides whether (lat, lng) is contained in the box. The
|
||||
// containment test is inclusive of the edges and corners.
|
||||
func (b Box) Contains(lat, lng float64) bool {
|
||||
return (b.MinLat <= lat && lat <= b.MaxLat &&
|
||||
b.MinLng <= lng && lng <= b.MaxLng)
|
||||
}
|
||||
|
||||
// errorWithPrecision returns the error range in latitude and longitude for in
|
||||
// integer geohash with bits of precision.
|
||||
func errorWithPrecision(bits uint) (latErr, lngErr float64) {
|
||||
b := int(bits)
|
||||
latBits := b / 2
|
||||
lngBits := b - latBits
|
||||
latErr = math.Ldexp(180.0, -latBits)
|
||||
lngErr = math.Ldexp(360.0, -lngBits)
|
||||
return
|
||||
}
|
||||
|
||||
// BoundingBox returns the region encoded by the given string geohash.
|
||||
func BoundingBox(hash string) Box {
|
||||
bits := uint(5 * len(hash))
|
||||
inthash := base32encoding.Decode(hash)
|
||||
return BoundingBoxIntWithPrecision(inthash, bits)
|
||||
}
|
||||
|
||||
// BoundingBoxIntWithPrecision returns the region encoded by the integer
|
||||
// geohash with the specified precision.
|
||||
func BoundingBoxIntWithPrecision(hash uint64, bits uint) Box {
|
||||
fullHash := hash << (64 - bits)
|
||||
latInt, lngInt := deinterleave(fullHash)
|
||||
lat := decodeRange(latInt, ENC_LAT)
|
||||
lng := decodeRange(lngInt, ENC_LONG)
|
||||
latErr, lngErr := errorWithPrecision(bits)
|
||||
return Box{
|
||||
MinLat: lat,
|
||||
MaxLat: lat + latErr,
|
||||
MinLng: lng,
|
||||
MaxLng: lng + lngErr,
|
||||
}
|
||||
}
|
||||
|
||||
// BoundingBoxInt returns the region encoded by the given 64-bit integer
|
||||
// geohash.
|
||||
func BoundingBoxInt(hash uint64) Box {
|
||||
return BoundingBoxIntWithPrecision(hash, 64)
|
||||
}
|
||||
|
||||
// DecodeCenter decodes the string geohash to the central point of the bounding box.
|
||||
func DecodeCenter(hash string) (lat, lng float64) {
|
||||
box := BoundingBox(hash)
|
||||
return box.Center()
|
||||
}
|
||||
|
||||
// DecodeIntWithPrecision decodes the provided integer geohash with bits of
|
||||
// precision to a (lat, lng) point.
|
||||
func DecodeIntWithPrecision(hash uint64, bits uint) (lat, lng float64) {
|
||||
box := BoundingBoxIntWithPrecision(hash, bits)
|
||||
return box.Center()
|
||||
}
|
||||
|
||||
// DecodeInt decodes the provided 64-bit integer geohash to a (lat, lng) point.
|
||||
func DecodeInt(hash uint64) (lat, lng float64) {
|
||||
return DecodeIntWithPrecision(hash, 64)
|
||||
}
|
||||
|
||||
// Neighbors returns a slice of geohash strings that correspond to the provided
|
||||
// geohash's neighbors.
|
||||
func Neighbors(hash string) []string {
|
||||
box := BoundingBox(hash)
|
||||
lat, lng := box.Center()
|
||||
latDelta := box.MaxLat - box.MinLat
|
||||
lngDelta := box.MaxLng - box.MinLng
|
||||
precision := uint(len(hash))
|
||||
return []string{
|
||||
// N
|
||||
EncodeWithPrecision(lat+latDelta, lng, precision),
|
||||
// NE,
|
||||
EncodeWithPrecision(lat+latDelta, lng+lngDelta, precision),
|
||||
// E,
|
||||
EncodeWithPrecision(lat, lng+lngDelta, precision),
|
||||
// SE,
|
||||
EncodeWithPrecision(lat-latDelta, lng+lngDelta, precision),
|
||||
// S,
|
||||
EncodeWithPrecision(lat-latDelta, lng, precision),
|
||||
// SW,
|
||||
EncodeWithPrecision(lat-latDelta, lng-lngDelta, precision),
|
||||
// W,
|
||||
EncodeWithPrecision(lat, lng-lngDelta, precision),
|
||||
// NW
|
||||
EncodeWithPrecision(lat+latDelta, lng-lngDelta, precision),
|
||||
}
|
||||
}
|
||||
|
||||
// NeighborsInt returns a slice of uint64s that correspond to the provided hash's
|
||||
// neighbors at 64-bit precision.
|
||||
func NeighborsInt(hash uint64) []uint64 {
|
||||
return NeighborsIntWithPrecision(hash, 64)
|
||||
}
|
||||
|
||||
// NeighborsIntWithPrecision returns a slice of uint64s that correspond to the
|
||||
// provided hash's neighbors at the given precision.
|
||||
func NeighborsIntWithPrecision(hash uint64, bits uint) []uint64 {
|
||||
box := BoundingBoxIntWithPrecision(hash, bits)
|
||||
lat, lng := box.Center()
|
||||
latDelta := box.MaxLat - box.MinLat
|
||||
lngDelta := box.MaxLng - box.MinLng
|
||||
return []uint64{
|
||||
// N
|
||||
EncodeIntWithPrecision(lat+latDelta, lng, bits),
|
||||
// NE,
|
||||
EncodeIntWithPrecision(lat+latDelta, lng+lngDelta, bits),
|
||||
// E,
|
||||
EncodeIntWithPrecision(lat, lng+lngDelta, bits),
|
||||
// SE,
|
||||
EncodeIntWithPrecision(lat-latDelta, lng+lngDelta, bits),
|
||||
// S,
|
||||
EncodeIntWithPrecision(lat-latDelta, lng, bits),
|
||||
// SW,
|
||||
EncodeIntWithPrecision(lat-latDelta, lng-lngDelta, bits),
|
||||
// W,
|
||||
EncodeIntWithPrecision(lat, lng-lngDelta, bits),
|
||||
// NW
|
||||
EncodeIntWithPrecision(lat+latDelta, lng-lngDelta, bits),
|
||||
}
|
||||
}
|
||||
|
||||
// Neighbor returns a geohash string that corresponds to the provided
|
||||
// geohash's neighbor in the provided direction
|
||||
func Neighbor(hash string, direction Direction) string {
|
||||
return Neighbors(hash)[direction]
|
||||
}
|
||||
|
||||
// NeighborInt returns a uint64 that corresponds to the provided hash's
|
||||
// neighbor in the provided direction at 64-bit precision.
|
||||
func NeighborInt(hash uint64, direction Direction) uint64 {
|
||||
return NeighborsIntWithPrecision(hash, 64)[direction]
|
||||
}
|
||||
|
||||
// NeighborIntWithPrecision returns a uint64s that corresponds to the
|
||||
// provided hash's neighbor in the provided direction at the given precision.
|
||||
func NeighborIntWithPrecision(hash uint64, bits uint, direction Direction) uint64 {
|
||||
return NeighborsIntWithPrecision(hash, bits)[direction]
|
||||
}
|
||||
|
||||
// precalculated for performance
|
||||
var exp232 = math.Exp2(32)
|
||||
|
||||
// Encode the position of x within the range -r to +r as a 32-bit integer.
|
||||
func encodeRange(x, r float64) uint32 {
|
||||
p := (x + r) / (2 * r)
|
||||
return uint32(p * exp232)
|
||||
}
|
||||
|
||||
// Decode the 32-bit range encoding X back to a value in the range -r to +r.
|
||||
func decodeRange(X uint32, r float64) float64 {
|
||||
p := float64(X) / exp232
|
||||
x := 2*r*p - r
|
||||
return x
|
||||
}
|
||||
|
||||
// Spread out the 32 bits of x into 64 bits, where the bits of x occupy even
|
||||
// bit positions.
|
||||
func spread(x uint32) uint64 {
|
||||
X := uint64(x)
|
||||
X = (X | (X << 16)) & 0x0000ffff0000ffff
|
||||
X = (X | (X << 8)) & 0x00ff00ff00ff00ff
|
||||
X = (X | (X << 4)) & 0x0f0f0f0f0f0f0f0f
|
||||
X = (X | (X << 2)) & 0x3333333333333333
|
||||
X = (X | (X << 1)) & 0x5555555555555555
|
||||
return X
|
||||
}
|
||||
|
||||
// Interleave the bits of x and y. In the result, x and y occupy even and odd
|
||||
// bitlevels, respectively.
|
||||
func interleave(x, y uint32) uint64 {
|
||||
return spread(x) | (spread(y) << 1)
|
||||
}
|
||||
|
||||
// Squash the even bitlevels of X into a 32-bit word. Odd bitlevels of X are
|
||||
// ignored, and may take any value.
|
||||
func squash(X uint64) uint32 {
|
||||
X &= 0x5555555555555555
|
||||
X = (X | (X >> 1)) & 0x3333333333333333
|
||||
X = (X | (X >> 2)) & 0x0f0f0f0f0f0f0f0f
|
||||
X = (X | (X >> 4)) & 0x00ff00ff00ff00ff
|
||||
X = (X | (X >> 8)) & 0x0000ffff0000ffff
|
||||
X = (X | (X >> 16)) & 0x00000000ffffffff
|
||||
return uint32(X)
|
||||
}
|
||||
|
||||
// Deinterleave the bits of X into 32-bit words containing the even and odd
|
||||
// bitlevels of X, respectively.
|
||||
func deinterleave(X uint64) (uint32, uint32) {
|
||||
return squash(X), squash(X >> 1)
|
||||
}
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
This is free and unencumbered software released into the public domain.
|
||||
|
||||
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||
distribute this software, either in source code form or as a compiled
|
||||
binary, for any purpose, commercial or non-commercial, and by any
|
||||
means.
|
||||
|
||||
In jurisdictions that recognize copyright laws, the author or authors
|
||||
of this software dedicate any and all copyright interest in the
|
||||
software to the public domain. We make this dedication for the benefit
|
||||
of the public at large and to the detriment of our heirs and
|
||||
successors. We intend this dedication to be an overt act of
|
||||
relinquishment in perpetuity of all present and future rights to this
|
||||
software under copyright law.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
For more information, please refer to <http://unlicense.org/>
|
||||
+1
@@ -0,0 +1 @@
|
||||
Copied from https://github.com/layeh/gopher-json and https://github.com/alicebob/gopher-json
|
||||
+189
@@ -0,0 +1,189 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// Preload adds json to the given Lua state's package.preload table. After it
|
||||
// has been preloaded, it can be loaded using require:
|
||||
//
|
||||
// local json = require("json")
|
||||
func Preload(L *lua.LState) {
|
||||
L.PreloadModule("json", Loader)
|
||||
}
|
||||
|
||||
// Loader is the module loader function.
|
||||
func Loader(L *lua.LState) int {
|
||||
t := L.NewTable()
|
||||
L.SetFuncs(t, api)
|
||||
L.Push(t)
|
||||
return 1
|
||||
}
|
||||
|
||||
var api = map[string]lua.LGFunction{
|
||||
"decode": apiDecode,
|
||||
"encode": apiEncode,
|
||||
}
|
||||
|
||||
func apiDecode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to decode"), 1)
|
||||
return 0
|
||||
}
|
||||
str := L.CheckString(1)
|
||||
|
||||
value, err := Decode(L, []byte(str))
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(value)
|
||||
return 1
|
||||
}
|
||||
|
||||
func apiEncode(L *lua.LState) int {
|
||||
if L.GetTop() != 1 {
|
||||
L.Error(lua.LString("bad argument #1 to encode"), 1)
|
||||
return 0
|
||||
}
|
||||
value := L.CheckAny(1)
|
||||
|
||||
data, err := Encode(value)
|
||||
if err != nil {
|
||||
L.Push(lua.LNil)
|
||||
L.Push(lua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(lua.LString(string(data)))
|
||||
return 1
|
||||
}
|
||||
|
||||
var (
|
||||
errNested = errors.New("cannot encode recursively nested tables to JSON")
|
||||
errSparseArray = errors.New("cannot encode sparse array")
|
||||
errInvalidKeys = errors.New("cannot encode mixed or invalid key types")
|
||||
)
|
||||
|
||||
type invalidTypeError lua.LValueType
|
||||
|
||||
func (i invalidTypeError) Error() string {
|
||||
return `cannot encode ` + lua.LValueType(i).String() + ` to JSON`
|
||||
}
|
||||
|
||||
// Encode returns the JSON encoding of value.
|
||||
func Encode(value lua.LValue) ([]byte, error) {
|
||||
return json.Marshal(jsonValue{
|
||||
LValue: value,
|
||||
visited: make(map[*lua.LTable]bool),
|
||||
})
|
||||
}
|
||||
|
||||
type jsonValue struct {
|
||||
lua.LValue
|
||||
visited map[*lua.LTable]bool
|
||||
}
|
||||
|
||||
func (j jsonValue) MarshalJSON() (data []byte, err error) {
|
||||
switch converted := j.LValue.(type) {
|
||||
case lua.LBool:
|
||||
data, err = json.Marshal(bool(converted))
|
||||
case lua.LNumber:
|
||||
data, err = json.Marshal(float64(converted))
|
||||
case *lua.LNilType:
|
||||
data = []byte(`null`)
|
||||
case lua.LString:
|
||||
data, err = json.Marshal(string(converted))
|
||||
case *lua.LTable:
|
||||
if j.visited[converted] {
|
||||
return nil, errNested
|
||||
}
|
||||
j.visited[converted] = true
|
||||
|
||||
key, value := converted.Next(lua.LNil)
|
||||
|
||||
switch key.Type() {
|
||||
case lua.LTNil: // empty table
|
||||
data = []byte(`[]`)
|
||||
case lua.LTNumber:
|
||||
arr := make([]jsonValue, 0, converted.Len())
|
||||
expectedKey := lua.LNumber(1)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTNumber {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
if expectedKey != key {
|
||||
err = errSparseArray
|
||||
return
|
||||
}
|
||||
arr = append(arr, jsonValue{value, j.visited})
|
||||
expectedKey++
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(arr)
|
||||
case lua.LTString:
|
||||
obj := make(map[string]jsonValue)
|
||||
for key != lua.LNil {
|
||||
if key.Type() != lua.LTString {
|
||||
err = errInvalidKeys
|
||||
return
|
||||
}
|
||||
obj[key.String()] = jsonValue{value, j.visited}
|
||||
key, value = converted.Next(key)
|
||||
}
|
||||
data, err = json.Marshal(obj)
|
||||
default:
|
||||
err = errInvalidKeys
|
||||
}
|
||||
default:
|
||||
err = invalidTypeError(j.LValue.Type())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Decode converts the JSON encoded data to Lua values.
|
||||
func Decode(L *lua.LState, data []byte) (lua.LValue, error) {
|
||||
var value interface{}
|
||||
err := json.Unmarshal(data, &value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeValue(L, value), nil
|
||||
}
|
||||
|
||||
// DecodeValue converts the value to a Lua value.
|
||||
//
|
||||
// This function only converts values that the encoding/json package decodes to.
|
||||
// All other values will return lua.LNil.
|
||||
func DecodeValue(L *lua.LState, value interface{}) lua.LValue {
|
||||
switch converted := value.(type) {
|
||||
case bool:
|
||||
return lua.LBool(converted)
|
||||
case float64:
|
||||
return lua.LNumber(converted)
|
||||
case string:
|
||||
return lua.LString(converted)
|
||||
case json.Number:
|
||||
return lua.LString(converted)
|
||||
case []interface{}:
|
||||
arr := L.CreateTable(len(converted), 0)
|
||||
for _, item := range converted {
|
||||
arr.Append(DecodeValue(L, item))
|
||||
}
|
||||
return arr
|
||||
case map[string]interface{}:
|
||||
tbl := L.CreateTable(0, len(converted))
|
||||
for key, item := range converted {
|
||||
tbl.RawSetH(lua.LString(key), DecodeValue(L, item))
|
||||
}
|
||||
return tbl
|
||||
case nil:
|
||||
return lua.LNil
|
||||
}
|
||||
|
||||
return lua.LNil
|
||||
}
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"github.com/alicebob/miniredis/v2/hyperloglog"
|
||||
)
|
||||
|
||||
type hll struct {
|
||||
inner *hyperloglog.Sketch
|
||||
}
|
||||
|
||||
func newHll() *hll {
|
||||
return &hll{
|
||||
inner: hyperloglog.New14(),
|
||||
}
|
||||
}
|
||||
|
||||
// Add returns true if cardinality has been changed, or false otherwise.
|
||||
func (h *hll) Add(item []byte) bool {
|
||||
return h.inner.Insert(item)
|
||||
}
|
||||
|
||||
// Count returns the estimation of a set cardinality.
|
||||
func (h *hll) Count() int {
|
||||
return int(h.inner.Estimate())
|
||||
}
|
||||
|
||||
// Merge merges the other hll into original one (not making a copy but doing this in place).
|
||||
func (h *hll) Merge(other *hll) {
|
||||
_ = h.inner.Merge(other.inner)
|
||||
}
|
||||
|
||||
// Bytes returns raw-bytes representation of hll data structure.
|
||||
func (h *hll) Bytes() []byte {
|
||||
dataBytes, _ := h.inner.MarshalBinary()
|
||||
return dataBytes
|
||||
}
|
||||
|
||||
func (h *hll) copy() *hll {
|
||||
return &hll{
|
||||
inner: h.inner.Clone(),
|
||||
}
|
||||
}
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2017 Axiom Inc. <seif@axiom.sh>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+1
@@ -0,0 +1 @@
|
||||
This is a copy of github.com/axiomhq/hyperloglog.
|
||||
+180
@@ -0,0 +1,180 @@
|
||||
package hyperloglog
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// Original author of this file is github.com/clarkduvall/hyperloglog
|
||||
type iterable interface {
|
||||
decode(i int, last uint32) (uint32, int)
|
||||
Len() int
|
||||
Iter() *iterator
|
||||
}
|
||||
|
||||
type iterator struct {
|
||||
i int
|
||||
last uint32
|
||||
v iterable
|
||||
}
|
||||
|
||||
func (iter *iterator) Next() uint32 {
|
||||
n, i := iter.v.decode(iter.i, iter.last)
|
||||
iter.last = n
|
||||
iter.i = i
|
||||
return n
|
||||
}
|
||||
|
||||
func (iter *iterator) Peek() uint32 {
|
||||
n, _ := iter.v.decode(iter.i, iter.last)
|
||||
return n
|
||||
}
|
||||
|
||||
func (iter iterator) HasNext() bool {
|
||||
return iter.i < iter.v.Len()
|
||||
}
|
||||
|
||||
type compressedList struct {
|
||||
count uint32
|
||||
last uint32
|
||||
b variableLengthList
|
||||
}
|
||||
|
||||
func (v *compressedList) Clone() *compressedList {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
newV := &compressedList{
|
||||
count: v.count,
|
||||
last: v.last,
|
||||
}
|
||||
|
||||
newV.b = make(variableLengthList, len(v.b))
|
||||
copy(newV.b, v.b)
|
||||
return newV
|
||||
}
|
||||
|
||||
func (v *compressedList) MarshalBinary() (data []byte, err error) {
|
||||
// Marshal the variableLengthList
|
||||
bdata, err := v.b.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// At least 4 bytes for the two fixed sized values plus the size of bdata.
|
||||
data = make([]byte, 0, 4+4+len(bdata))
|
||||
|
||||
// Marshal the count and last values.
|
||||
data = append(data, []byte{
|
||||
// Number of items in the list.
|
||||
byte(v.count >> 24),
|
||||
byte(v.count >> 16),
|
||||
byte(v.count >> 8),
|
||||
byte(v.count),
|
||||
// The last item in the list.
|
||||
byte(v.last >> 24),
|
||||
byte(v.last >> 16),
|
||||
byte(v.last >> 8),
|
||||
byte(v.last),
|
||||
}...)
|
||||
|
||||
// Append the list
|
||||
return append(data, bdata...), nil
|
||||
}
|
||||
|
||||
func (v *compressedList) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < 12 {
|
||||
return ErrorTooShort
|
||||
}
|
||||
|
||||
// Set the count.
|
||||
v.count, data = binary.BigEndian.Uint32(data[:4]), data[4:]
|
||||
|
||||
// Set the last value.
|
||||
v.last, data = binary.BigEndian.Uint32(data[:4]), data[4:]
|
||||
|
||||
// Set the list.
|
||||
sz, data := binary.BigEndian.Uint32(data[:4]), data[4:]
|
||||
v.b = make([]uint8, sz)
|
||||
if uint32(len(data)) < sz {
|
||||
return ErrorTooShort
|
||||
}
|
||||
for i := uint32(0); i < sz; i++ {
|
||||
v.b[i] = data[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCompressedList() *compressedList {
|
||||
v := &compressedList{}
|
||||
v.b = make(variableLengthList, 0)
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *compressedList) Len() int {
|
||||
return len(v.b)
|
||||
}
|
||||
|
||||
func (v *compressedList) decode(i int, last uint32) (uint32, int) {
|
||||
n, i := v.b.decode(i, last)
|
||||
return n + last, i
|
||||
}
|
||||
|
||||
func (v *compressedList) Append(x uint32) {
|
||||
v.count++
|
||||
v.b = v.b.Append(x - v.last)
|
||||
v.last = x
|
||||
}
|
||||
|
||||
func (v *compressedList) Iter() *iterator {
|
||||
return &iterator{0, 0, v}
|
||||
}
|
||||
|
||||
type variableLengthList []uint8
|
||||
|
||||
func (v variableLengthList) MarshalBinary() (data []byte, err error) {
|
||||
// 4 bytes for the size of the list, and a byte for each element in the
|
||||
// list.
|
||||
data = make([]byte, 0, 4+v.Len())
|
||||
|
||||
// Length of the list. We only need 32 bits because the size of the set
|
||||
// couldn't exceed that on 32 bit architectures.
|
||||
sz := v.Len()
|
||||
data = append(data, []byte{
|
||||
byte(sz >> 24),
|
||||
byte(sz >> 16),
|
||||
byte(sz >> 8),
|
||||
byte(sz),
|
||||
}...)
|
||||
|
||||
// Marshal each element in the list.
|
||||
for i := 0; i < sz; i++ {
|
||||
data = append(data, v[i])
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (v variableLengthList) Len() int {
|
||||
return len(v)
|
||||
}
|
||||
|
||||
func (v *variableLengthList) Iter() *iterator {
|
||||
return &iterator{0, 0, v}
|
||||
}
|
||||
|
||||
func (v variableLengthList) decode(i int, last uint32) (uint32, int) {
|
||||
var x uint32
|
||||
j := i
|
||||
for ; v[j]&0x80 != 0; j++ {
|
||||
x |= uint32(v[j]&0x7f) << (uint(j-i) * 7)
|
||||
}
|
||||
x |= uint32(v[j]) << (uint(j-i) * 7)
|
||||
return x, j + 1
|
||||
}
|
||||
|
||||
func (v variableLengthList) Append(x uint32) variableLengthList {
|
||||
for x&0xffffff80 != 0 {
|
||||
v = append(v, uint8((x&0x7f)|0x80))
|
||||
x >>= 7
|
||||
}
|
||||
return append(v, uint8(x&0x7f))
|
||||
}
|
||||
+424
@@ -0,0 +1,424 @@
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
const (
|
||||
capacity = uint8(16)
|
||||
pp = uint8(25)
|
||||
mp = uint32(1) << pp
|
||||
version = 1
|
||||
)
|
||||
|
||||
// Sketch is a HyperLogLog data-structure for the count-distinct problem,
|
||||
// approximating the number of distinct elements in a multiset.
|
||||
type Sketch struct {
|
||||
p uint8
|
||||
b uint8
|
||||
m uint32
|
||||
alpha float64
|
||||
tmpSet set
|
||||
sparseList *compressedList
|
||||
regs *registers
|
||||
}
|
||||
|
||||
// New returns a HyperLogLog Sketch with 2^14 registers (precision 14)
|
||||
func New() *Sketch {
|
||||
return New14()
|
||||
}
|
||||
|
||||
// New14 returns a HyperLogLog Sketch with 2^14 registers (precision 14)
|
||||
func New14() *Sketch {
|
||||
sk, _ := newSketch(14, true)
|
||||
return sk
|
||||
}
|
||||
|
||||
// New16 returns a HyperLogLog Sketch with 2^16 registers (precision 16)
|
||||
func New16() *Sketch {
|
||||
sk, _ := newSketch(16, true)
|
||||
return sk
|
||||
}
|
||||
|
||||
// NewNoSparse returns a HyperLogLog Sketch with 2^14 registers (precision 14)
|
||||
// that will not use a sparse representation
|
||||
func NewNoSparse() *Sketch {
|
||||
sk, _ := newSketch(14, false)
|
||||
return sk
|
||||
}
|
||||
|
||||
// New16NoSparse returns a HyperLogLog Sketch with 2^16 registers (precision 16)
|
||||
// that will not use a sparse representation
|
||||
func New16NoSparse() *Sketch {
|
||||
sk, _ := newSketch(16, false)
|
||||
return sk
|
||||
}
|
||||
|
||||
// newSketch returns a HyperLogLog Sketch with 2^precision registers
|
||||
func newSketch(precision uint8, sparse bool) (*Sketch, error) {
|
||||
if precision < 4 || precision > 18 {
|
||||
return nil, fmt.Errorf("p has to be >= 4 and <= 18")
|
||||
}
|
||||
m := uint32(math.Pow(2, float64(precision)))
|
||||
s := &Sketch{
|
||||
m: m,
|
||||
p: precision,
|
||||
alpha: alpha(float64(m)),
|
||||
}
|
||||
if sparse {
|
||||
s.tmpSet = set{}
|
||||
s.sparseList = newCompressedList()
|
||||
} else {
|
||||
s.regs = newRegisters(m)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (sk *Sketch) sparse() bool {
|
||||
return sk.sparseList != nil
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of sk.
|
||||
func (sk *Sketch) Clone() *Sketch {
|
||||
return &Sketch{
|
||||
b: sk.b,
|
||||
p: sk.p,
|
||||
m: sk.m,
|
||||
alpha: sk.alpha,
|
||||
tmpSet: sk.tmpSet.Clone(),
|
||||
sparseList: sk.sparseList.Clone(),
|
||||
regs: sk.regs.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Converts to normal if the sparse list is too large.
|
||||
func (sk *Sketch) maybeToNormal() {
|
||||
if uint32(len(sk.tmpSet))*100 > sk.m {
|
||||
sk.mergeSparse()
|
||||
if uint32(sk.sparseList.Len()) > sk.m {
|
||||
sk.toNormal()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge takes another Sketch and combines it with Sketch h.
|
||||
// If Sketch h is using the sparse Sketch, it will be converted
|
||||
// to the normal Sketch.
|
||||
func (sk *Sketch) Merge(other *Sketch) error {
|
||||
if other == nil {
|
||||
// Nothing to do
|
||||
return nil
|
||||
}
|
||||
cpOther := other.Clone()
|
||||
|
||||
if sk.p != cpOther.p {
|
||||
return errors.New("precisions must be equal")
|
||||
}
|
||||
|
||||
if sk.sparse() && other.sparse() {
|
||||
for k := range other.tmpSet {
|
||||
sk.tmpSet.add(k)
|
||||
}
|
||||
for iter := other.sparseList.Iter(); iter.HasNext(); {
|
||||
sk.tmpSet.add(iter.Next())
|
||||
}
|
||||
sk.maybeToNormal()
|
||||
return nil
|
||||
}
|
||||
|
||||
if sk.sparse() {
|
||||
sk.toNormal()
|
||||
}
|
||||
|
||||
if cpOther.sparse() {
|
||||
for k := range cpOther.tmpSet {
|
||||
i, r := decodeHash(k, cpOther.p, pp)
|
||||
sk.insert(i, r)
|
||||
}
|
||||
|
||||
for iter := cpOther.sparseList.Iter(); iter.HasNext(); {
|
||||
i, r := decodeHash(iter.Next(), cpOther.p, pp)
|
||||
sk.insert(i, r)
|
||||
}
|
||||
} else {
|
||||
if sk.b < cpOther.b {
|
||||
sk.regs.rebase(cpOther.b - sk.b)
|
||||
sk.b = cpOther.b
|
||||
} else {
|
||||
cpOther.regs.rebase(sk.b - cpOther.b)
|
||||
cpOther.b = sk.b
|
||||
}
|
||||
|
||||
for i, v := range cpOther.regs.tailcuts {
|
||||
v1 := v.get(0)
|
||||
if v1 > sk.regs.get(uint32(i)*2) {
|
||||
sk.regs.set(uint32(i)*2, v1)
|
||||
}
|
||||
v2 := v.get(1)
|
||||
if v2 > sk.regs.get(1+uint32(i)*2) {
|
||||
sk.regs.set(1+uint32(i)*2, v2)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert from sparse Sketch to dense Sketch.
|
||||
func (sk *Sketch) toNormal() {
|
||||
if len(sk.tmpSet) > 0 {
|
||||
sk.mergeSparse()
|
||||
}
|
||||
|
||||
sk.regs = newRegisters(sk.m)
|
||||
for iter := sk.sparseList.Iter(); iter.HasNext(); {
|
||||
i, r := decodeHash(iter.Next(), sk.p, pp)
|
||||
sk.insert(i, r)
|
||||
}
|
||||
|
||||
sk.tmpSet = nil
|
||||
sk.sparseList = nil
|
||||
}
|
||||
|
||||
func (sk *Sketch) insert(i uint32, r uint8) bool {
|
||||
changed := false
|
||||
if r-sk.b >= capacity {
|
||||
//overflow
|
||||
db := sk.regs.min()
|
||||
if db > 0 {
|
||||
sk.b += db
|
||||
sk.regs.rebase(db)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if r > sk.b {
|
||||
val := r - sk.b
|
||||
if c1 := capacity - 1; c1 < val {
|
||||
val = c1
|
||||
}
|
||||
|
||||
if val > sk.regs.get(i) {
|
||||
sk.regs.set(i, val)
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// Insert adds element e to sketch
|
||||
func (sk *Sketch) Insert(e []byte) bool {
|
||||
x := hash(e)
|
||||
return sk.InsertHash(x)
|
||||
}
|
||||
|
||||
// InsertHash adds hash x to sketch
|
||||
func (sk *Sketch) InsertHash(x uint64) bool {
|
||||
if sk.sparse() {
|
||||
changed := sk.tmpSet.add(encodeHash(x, sk.p, pp))
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
if uint32(len(sk.tmpSet))*100 > sk.m/2 {
|
||||
sk.mergeSparse()
|
||||
if uint32(sk.sparseList.Len()) > sk.m/2 {
|
||||
sk.toNormal()
|
||||
}
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
i, r := getPosVal(x, sk.p)
|
||||
return sk.insert(uint32(i), r)
|
||||
}
|
||||
}
|
||||
|
||||
// Estimate returns the cardinality of the Sketch
|
||||
func (sk *Sketch) Estimate() uint64 {
|
||||
if sk.sparse() {
|
||||
sk.mergeSparse()
|
||||
return uint64(linearCount(mp, mp-sk.sparseList.count))
|
||||
}
|
||||
|
||||
sum, ez := sk.regs.sumAndZeros(sk.b)
|
||||
m := float64(sk.m)
|
||||
var est float64
|
||||
|
||||
var beta func(float64) float64
|
||||
if sk.p < 16 {
|
||||
beta = beta14
|
||||
} else {
|
||||
beta = beta16
|
||||
}
|
||||
|
||||
if sk.b == 0 {
|
||||
est = (sk.alpha * m * (m - ez) / (sum + beta(ez)))
|
||||
} else {
|
||||
est = (sk.alpha * m * m / sum)
|
||||
}
|
||||
|
||||
return uint64(est + 0.5)
|
||||
}
|
||||
|
||||
func (sk *Sketch) mergeSparse() {
|
||||
if len(sk.tmpSet) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
keys := make(uint64Slice, 0, len(sk.tmpSet))
|
||||
for k := range sk.tmpSet {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Sort(keys)
|
||||
|
||||
newList := newCompressedList()
|
||||
for iter, i := sk.sparseList.Iter(), 0; iter.HasNext() || i < len(keys); {
|
||||
if !iter.HasNext() {
|
||||
newList.Append(keys[i])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if i >= len(keys) {
|
||||
newList.Append(iter.Next())
|
||||
continue
|
||||
}
|
||||
|
||||
x1, x2 := iter.Peek(), keys[i]
|
||||
if x1 == x2 {
|
||||
newList.Append(iter.Next())
|
||||
i++
|
||||
} else if x1 > x2 {
|
||||
newList.Append(x2)
|
||||
i++
|
||||
} else {
|
||||
newList.Append(iter.Next())
|
||||
}
|
||||
}
|
||||
|
||||
sk.sparseList = newList
|
||||
sk.tmpSet = set{}
|
||||
}
|
||||
|
||||
// MarshalBinary implements the encoding.BinaryMarshaler interface.
|
||||
func (sk *Sketch) MarshalBinary() (data []byte, err error) {
|
||||
// Marshal a version marker.
|
||||
data = append(data, version)
|
||||
// Marshal p.
|
||||
data = append(data, sk.p)
|
||||
// Marshal b
|
||||
data = append(data, sk.b)
|
||||
|
||||
if sk.sparse() {
|
||||
// It's using the sparse Sketch.
|
||||
data = append(data, byte(1))
|
||||
|
||||
// Add the tmp_set
|
||||
tsdata, err := sk.tmpSet.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data = append(data, tsdata...)
|
||||
|
||||
// Add the sparse Sketch
|
||||
sdata, err := sk.sparseList.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(data, sdata...), nil
|
||||
}
|
||||
|
||||
// It's using the dense Sketch.
|
||||
data = append(data, byte(0))
|
||||
|
||||
// Add the dense sketch Sketch.
|
||||
sz := len(sk.regs.tailcuts)
|
||||
data = append(data, []byte{
|
||||
byte(sz >> 24),
|
||||
byte(sz >> 16),
|
||||
byte(sz >> 8),
|
||||
byte(sz),
|
||||
}...)
|
||||
|
||||
// Marshal each element in the list.
|
||||
for i := 0; i < len(sk.regs.tailcuts); i++ {
|
||||
data = append(data, byte(sk.regs.tailcuts[i]))
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ErrorTooShort is an error that UnmarshalBinary try to parse too short
|
||||
// binary.
|
||||
var ErrorTooShort = errors.New("too short binary")
|
||||
|
||||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
|
||||
func (sk *Sketch) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < 8 {
|
||||
return ErrorTooShort
|
||||
}
|
||||
|
||||
// Unmarshal version. We may need this in the future if we make
|
||||
// non-compatible changes.
|
||||
_ = data[0]
|
||||
|
||||
// Unmarshal p.
|
||||
p := data[1]
|
||||
|
||||
// Unmarshal b.
|
||||
sk.b = data[2]
|
||||
|
||||
// Determine if we need a sparse Sketch
|
||||
sparse := data[3] == byte(1)
|
||||
|
||||
// Make a newSketch Sketch if the precision doesn't match or if the Sketch was used
|
||||
if sk.p != p || sk.regs != nil || len(sk.tmpSet) > 0 || (sk.sparseList != nil && sk.sparseList.Len() > 0) {
|
||||
newh, err := newSketch(p, sparse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newh.b = sk.b
|
||||
*sk = *newh
|
||||
}
|
||||
|
||||
// h is now initialised with the correct p. We just need to fill the
|
||||
// rest of the details out.
|
||||
if sparse {
|
||||
// Using the sparse Sketch.
|
||||
|
||||
// Unmarshal the tmp_set.
|
||||
tssz := binary.BigEndian.Uint32(data[4:8])
|
||||
sk.tmpSet = make(map[uint32]struct{}, tssz)
|
||||
|
||||
// We need to unmarshal tssz values in total, and each value requires us
|
||||
// to read 4 bytes.
|
||||
tsLastByte := int((tssz * 4) + 8)
|
||||
for i := 8; i < tsLastByte; i += 4 {
|
||||
k := binary.BigEndian.Uint32(data[i : i+4])
|
||||
sk.tmpSet[k] = struct{}{}
|
||||
}
|
||||
|
||||
// Unmarshal the sparse Sketch.
|
||||
return sk.sparseList.UnmarshalBinary(data[tsLastByte:])
|
||||
}
|
||||
|
||||
// Using the dense Sketch.
|
||||
sk.sparseList = nil
|
||||
sk.tmpSet = nil
|
||||
dsz := binary.BigEndian.Uint32(data[4:8])
|
||||
sk.regs = newRegisters(dsz * 2)
|
||||
data = data[8:]
|
||||
|
||||
for i, val := range data {
|
||||
sk.regs.tailcuts[i] = reg(val)
|
||||
if uint8(sk.regs.tailcuts[i]<<4>>4) > 0 {
|
||||
sk.regs.nz--
|
||||
}
|
||||
if uint8(sk.regs.tailcuts[i]>>4) > 0 {
|
||||
sk.regs.nz--
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+114
@@ -0,0 +1,114 @@
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
type reg uint8
|
||||
type tailcuts []reg
|
||||
|
||||
type registers struct {
|
||||
tailcuts
|
||||
nz uint32
|
||||
}
|
||||
|
||||
func (r *reg) set(offset, val uint8) bool {
|
||||
var isZero bool
|
||||
if offset == 0 {
|
||||
isZero = *r < 16
|
||||
tmpVal := uint8((*r) << 4 >> 4)
|
||||
*r = reg(tmpVal | (val << 4))
|
||||
} else {
|
||||
isZero = *r&0x0f == 0
|
||||
tmpVal := uint8((*r) >> 4 << 4)
|
||||
*r = reg(tmpVal | val)
|
||||
}
|
||||
return isZero
|
||||
}
|
||||
|
||||
func (r *reg) get(offset uint8) uint8 {
|
||||
if offset == 0 {
|
||||
return uint8((*r) >> 4)
|
||||
}
|
||||
return uint8((*r) << 4 >> 4)
|
||||
}
|
||||
|
||||
func newRegisters(size uint32) *registers {
|
||||
return ®isters{
|
||||
tailcuts: make(tailcuts, size/2),
|
||||
nz: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *registers) clone() *registers {
|
||||
if rs == nil {
|
||||
return nil
|
||||
}
|
||||
tc := make([]reg, len(rs.tailcuts))
|
||||
copy(tc, rs.tailcuts)
|
||||
return ®isters{
|
||||
tailcuts: tc,
|
||||
nz: rs.nz,
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *registers) rebase(delta uint8) {
|
||||
nz := uint32(len(rs.tailcuts)) * 2
|
||||
for i := range rs.tailcuts {
|
||||
for j := uint8(0); j < 2; j++ {
|
||||
val := rs.tailcuts[i].get(j)
|
||||
if val >= delta {
|
||||
rs.tailcuts[i].set(j, val-delta)
|
||||
if val-delta > 0 {
|
||||
nz--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rs.nz = nz
|
||||
}
|
||||
|
||||
func (rs *registers) set(i uint32, val uint8) {
|
||||
offset, index := uint8(i)&1, i/2
|
||||
if rs.tailcuts[index].set(offset, val) {
|
||||
rs.nz--
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *registers) get(i uint32) uint8 {
|
||||
offset, index := uint8(i)&1, i/2
|
||||
return rs.tailcuts[index].get(offset)
|
||||
}
|
||||
|
||||
func (rs *registers) sumAndZeros(base uint8) (res, ez float64) {
|
||||
for _, r := range rs.tailcuts {
|
||||
for j := uint8(0); j < 2; j++ {
|
||||
v := float64(base + r.get(j))
|
||||
if v == 0 {
|
||||
ez++
|
||||
}
|
||||
res += 1.0 / math.Pow(2.0, v)
|
||||
}
|
||||
}
|
||||
rs.nz = uint32(ez)
|
||||
return res, ez
|
||||
}
|
||||
|
||||
func (rs *registers) min() uint8 {
|
||||
if rs.nz > 0 {
|
||||
return 0
|
||||
}
|
||||
min := uint8(math.MaxUint8)
|
||||
for _, r := range rs.tailcuts {
|
||||
if r == 0 || min == 0 {
|
||||
return 0
|
||||
}
|
||||
if val := uint8(r << 4 >> 4); val < min {
|
||||
min = val
|
||||
}
|
||||
if val := uint8(r >> 4); val < min {
|
||||
min = val
|
||||
}
|
||||
}
|
||||
return min
|
||||
}
|
||||
+92
@@ -0,0 +1,92 @@
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
func getIndex(k uint32, p, pp uint8) uint32 {
|
||||
if k&1 == 1 {
|
||||
return bextr32(k, 32-p, p)
|
||||
}
|
||||
return bextr32(k, pp-p+1, p)
|
||||
}
|
||||
|
||||
// Encode a hash to be used in the sparse representation.
|
||||
func encodeHash(x uint64, p, pp uint8) uint32 {
|
||||
idx := uint32(bextr(x, 64-pp, pp))
|
||||
if bextr(x, 64-pp, pp-p) == 0 {
|
||||
zeros := bits.LeadingZeros64((bextr(x, 0, 64-pp)<<pp)|(1<<pp-1)) + 1
|
||||
return idx<<7 | uint32(zeros<<1) | 1
|
||||
}
|
||||
return idx << 1
|
||||
}
|
||||
|
||||
// Decode a hash from the sparse representation.
|
||||
func decodeHash(k uint32, p, pp uint8) (uint32, uint8) {
|
||||
var r uint8
|
||||
if k&1 == 1 {
|
||||
r = uint8(bextr32(k, 1, 6)) + pp - p
|
||||
} else {
|
||||
// We can use the 64bit clz implementation and reduce the result
|
||||
// by 32 to get a clz for a 32bit word.
|
||||
r = uint8(bits.LeadingZeros64(uint64(k<<(32-pp+p-1))) - 31) // -32 + 1
|
||||
}
|
||||
return getIndex(k, p, pp), r
|
||||
}
|
||||
|
||||
type set map[uint32]struct{}
|
||||
|
||||
func (s set) add(v uint32) bool {
|
||||
_, ok := s[v]
|
||||
if ok {
|
||||
return false
|
||||
}
|
||||
s[v] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s set) Clone() set {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
newS := make(map[uint32]struct{}, len(s))
|
||||
for k, v := range s {
|
||||
newS[k] = v
|
||||
}
|
||||
return newS
|
||||
}
|
||||
|
||||
func (s set) MarshalBinary() (data []byte, err error) {
|
||||
// 4 bytes for the size of the set, and 4 bytes for each key.
|
||||
// list.
|
||||
data = make([]byte, 0, 4+(4*len(s)))
|
||||
|
||||
// Length of the set. We only need 32 bits because the size of the set
|
||||
// couldn't exceed that on 32 bit architectures.
|
||||
sl := len(s)
|
||||
data = append(data, []byte{
|
||||
byte(sl >> 24),
|
||||
byte(sl >> 16),
|
||||
byte(sl >> 8),
|
||||
byte(sl),
|
||||
}...)
|
||||
|
||||
// Marshal each element in the set.
|
||||
for k := range s {
|
||||
data = append(data, []byte{
|
||||
byte(k >> 24),
|
||||
byte(k >> 16),
|
||||
byte(k >> 8),
|
||||
byte(k),
|
||||
}...)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type uint64Slice []uint32
|
||||
|
||||
func (p uint64Slice) Len() int { return len(p) }
|
||||
func (p uint64Slice) Less(i, j int) bool { return p[i] < p[j] }
|
||||
func (p uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
+69
@@ -0,0 +1,69 @@
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"github.com/alicebob/miniredis/v2/metro"
|
||||
"math"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
var hash = hashFunc
|
||||
|
||||
func beta14(ez float64) float64 {
|
||||
zl := math.Log(ez + 1)
|
||||
return -0.370393911*ez +
|
||||
0.070471823*zl +
|
||||
0.17393686*math.Pow(zl, 2) +
|
||||
0.16339839*math.Pow(zl, 3) +
|
||||
-0.09237745*math.Pow(zl, 4) +
|
||||
0.03738027*math.Pow(zl, 5) +
|
||||
-0.005384159*math.Pow(zl, 6) +
|
||||
0.00042419*math.Pow(zl, 7)
|
||||
}
|
||||
|
||||
func beta16(ez float64) float64 {
|
||||
zl := math.Log(ez + 1)
|
||||
return -0.37331876643753059*ez +
|
||||
-1.41704077448122989*zl +
|
||||
0.40729184796612533*math.Pow(zl, 2) +
|
||||
1.56152033906584164*math.Pow(zl, 3) +
|
||||
-0.99242233534286128*math.Pow(zl, 4) +
|
||||
0.26064681399483092*math.Pow(zl, 5) +
|
||||
-0.03053811369682807*math.Pow(zl, 6) +
|
||||
0.00155770210179105*math.Pow(zl, 7)
|
||||
}
|
||||
|
||||
func alpha(m float64) float64 {
|
||||
switch m {
|
||||
case 16:
|
||||
return 0.673
|
||||
case 32:
|
||||
return 0.697
|
||||
case 64:
|
||||
return 0.709
|
||||
}
|
||||
return 0.7213 / (1 + 1.079/m)
|
||||
}
|
||||
|
||||
func getPosVal(x uint64, p uint8) (uint64, uint8) {
|
||||
i := bextr(x, 64-p, p) // {x63,...,x64-p}
|
||||
w := x<<p | 1<<(p-1) // {x63-p,...,x0}
|
||||
rho := uint8(bits.LeadingZeros64(w)) + 1
|
||||
return i, rho
|
||||
}
|
||||
|
||||
func linearCount(m uint32, v uint32) float64 {
|
||||
fm := float64(m)
|
||||
return fm * math.Log(fm/float64(v))
|
||||
}
|
||||
|
||||
func bextr(v uint64, start, length uint8) uint64 {
|
||||
return (v >> start) & ((1 << length) - 1)
|
||||
}
|
||||
|
||||
func bextr32(v uint32, start, length uint8) uint32 {
|
||||
return (v >> start) & ((1 << length) - 1)
|
||||
}
|
||||
|
||||
func hashFunc(e []byte) uint64 {
|
||||
return metro.Hash64(e, 1337)
|
||||
}
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
package miniredis
|
||||
|
||||
// Translate the 'KEYS' or 'PSUBSCRIBE' argument ('foo*', 'f??', &c.) into a regexp.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// patternRE compiles a glob to a regexp. Returns nil if the given
|
||||
// pattern will never match anything.
|
||||
// The general strategy is to sandwich all non-meta characters between \Q...\E.
|
||||
func patternRE(k string) *regexp.Regexp {
|
||||
re := bytes.Buffer{}
|
||||
re.WriteString(`(?s)^\Q`)
|
||||
for i := 0; i < len(k); i++ {
|
||||
p := k[i]
|
||||
switch p {
|
||||
case '*':
|
||||
re.WriteString(`\E.*\Q`)
|
||||
case '?':
|
||||
re.WriteString(`\E.\Q`)
|
||||
case '[':
|
||||
charClass := bytes.Buffer{}
|
||||
i++
|
||||
for ; i < len(k); i++ {
|
||||
if k[i] == ']' {
|
||||
break
|
||||
}
|
||||
if k[i] == '\\' {
|
||||
if i == len(k)-1 {
|
||||
// Ends with a '\'. U-huh.
|
||||
return nil
|
||||
}
|
||||
charClass.WriteByte(k[i])
|
||||
i++
|
||||
charClass.WriteByte(k[i])
|
||||
continue
|
||||
}
|
||||
charClass.WriteByte(k[i])
|
||||
}
|
||||
if charClass.Len() == 0 {
|
||||
// '[]' is valid in Redis, but matches nothing.
|
||||
return nil
|
||||
}
|
||||
re.WriteString(`\E[`)
|
||||
re.Write(charClass.Bytes())
|
||||
re.WriteString(`]\Q`)
|
||||
|
||||
case '\\':
|
||||
if i == len(k)-1 {
|
||||
// Ends with a '\'. U-huh.
|
||||
return nil
|
||||
}
|
||||
// Forget the \, keep the next char.
|
||||
i++
|
||||
re.WriteByte(k[i])
|
||||
continue
|
||||
default:
|
||||
re.WriteByte(p)
|
||||
}
|
||||
}
|
||||
re.WriteString(`\E$`)
|
||||
return regexp.MustCompile(re.String())
|
||||
}
|
||||
|
||||
// matchKeys filters only matching keys.
|
||||
// The returned boolean is whether the match pattern was valid
|
||||
func matchKeys(keys []string, match string) ([]string, bool) {
|
||||
re := patternRE(match)
|
||||
if re == nil {
|
||||
// Special case: the given pattern won't match anything or is invalid.
|
||||
return nil, false
|
||||
}
|
||||
var res []string
|
||||
for _, k := range keys {
|
||||
if !re.MatchString(k) {
|
||||
continue
|
||||
}
|
||||
res = append(res, k)
|
||||
}
|
||||
return res, true
|
||||
}
|
||||
+281
@@ -0,0 +1,281 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
var luaRedisConstants = map[string]lua.LValue{
|
||||
"LOG_DEBUG": lua.LNumber(0),
|
||||
"LOG_VERBOSE": lua.LNumber(1),
|
||||
"LOG_NOTICE": lua.LNumber(2),
|
||||
"LOG_WARNING": lua.LNumber(3),
|
||||
}
|
||||
|
||||
func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) {
|
||||
mkCall := func(failFast bool) func(l *lua.LState) int {
|
||||
// one server.Ctx for a single Lua run
|
||||
pCtx := &connCtx{}
|
||||
if getCtx(c).authenticated {
|
||||
pCtx.authenticated = true
|
||||
}
|
||||
pCtx.nested = true
|
||||
pCtx.nestedSHA = sha
|
||||
pCtx.selectedDB = getCtx(c).selectedDB
|
||||
|
||||
return func(l *lua.LState) int {
|
||||
top := l.GetTop()
|
||||
if top == 0 {
|
||||
l.Error(lua.LString(fmt.Sprintf("Please specify at least one argument for this redis lib call script: %s, &c.", sha)), 1)
|
||||
return 0
|
||||
}
|
||||
var args []string
|
||||
for i := 1; i <= top; i++ {
|
||||
switch a := l.Get(i).(type) {
|
||||
case lua.LNumber:
|
||||
args = append(args, a.String())
|
||||
case lua.LString:
|
||||
args = append(args, string(a))
|
||||
default:
|
||||
l.Error(lua.LString(fmt.Sprintf("Lua redis lib command arguments must be strings or integers script: %s, &c.", sha)), 1)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if len(args) == 0 {
|
||||
l.Error(lua.LString(msgNotFromScripts(sha)), 1)
|
||||
return 0
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
wr := bufio.NewWriter(buf)
|
||||
peer := server.NewPeer(wr)
|
||||
peer.Ctx = pCtx
|
||||
srv.Dispatch(peer, args)
|
||||
wr.Flush()
|
||||
|
||||
res, err := server.ParseReply(bufio.NewReader(buf))
|
||||
if err != nil {
|
||||
if failFast {
|
||||
// call() mode
|
||||
if strings.Contains(err.Error(), "ERR unknown command") {
|
||||
l.Error(lua.LString(fmt.Sprintf("Unknown Redis command called from script script: %s, &c.", sha)), 1)
|
||||
} else {
|
||||
l.Error(lua.LString(err.Error()), 1)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
// pcall() mode
|
||||
l.Push(lua.LNil)
|
||||
return 1
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
l.Push(lua.LFalse)
|
||||
} else {
|
||||
switch r := res.(type) {
|
||||
case int64:
|
||||
l.Push(lua.LNumber(r))
|
||||
case int:
|
||||
l.Push(lua.LNumber(r))
|
||||
case []uint8:
|
||||
l.Push(lua.LString(string(r)))
|
||||
case []interface{}:
|
||||
l.Push(redisToLua(l, r))
|
||||
case server.Simple:
|
||||
l.Push(luaStatusReply(string(r)))
|
||||
case string:
|
||||
l.Push(lua.LString(r))
|
||||
case error:
|
||||
l.Error(lua.LString(r.Error()), 1)
|
||||
return 0
|
||||
default:
|
||||
panic(fmt.Sprintf("type not handled (%T)", r))
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]lua.LGFunction{
|
||||
"call": mkCall(true),
|
||||
"pcall": mkCall(false),
|
||||
"error_reply": func(l *lua.LState) int {
|
||||
v := l.Get(1)
|
||||
msg, ok := v.(lua.LString)
|
||||
if !ok {
|
||||
l.Error(lua.LString("wrong number or type of arguments"), 1)
|
||||
return 0
|
||||
}
|
||||
res := &lua.LTable{}
|
||||
parts := strings.SplitN(msg.String(), " ", 2)
|
||||
// '-' at the beginging will be added as a part of error response
|
||||
if parts[0] != "" && parts[0][0] == '-' {
|
||||
parts[0] = parts[0][1:]
|
||||
}
|
||||
var final_msg string
|
||||
if len(parts) == 2 {
|
||||
final_msg = fmt.Sprintf("%s %s", parts[0], parts[1])
|
||||
} else {
|
||||
final_msg = fmt.Sprintf("ERR %s", parts[0])
|
||||
}
|
||||
res.RawSetString("err", lua.LString(final_msg))
|
||||
l.Push(res)
|
||||
return 1
|
||||
},
|
||||
"log": func(l *lua.LState) int {
|
||||
level := l.CheckInt(1)
|
||||
msg := l.CheckString(2)
|
||||
_, _ = level, msg
|
||||
// do nothing by default. To see logs uncomment:
|
||||
// fmt.Printf("%v: %v", level, msg)
|
||||
return 0
|
||||
},
|
||||
"status_reply": func(l *lua.LState) int {
|
||||
v := l.Get(1)
|
||||
msg, ok := v.(lua.LString)
|
||||
if !ok {
|
||||
l.Error(lua.LString("wrong number or type of arguments"), 1)
|
||||
return 0
|
||||
}
|
||||
res := luaStatusReply(string(msg))
|
||||
l.Push(res)
|
||||
return 1
|
||||
},
|
||||
"sha1hex": func(l *lua.LState) int {
|
||||
top := l.GetTop()
|
||||
if top != 1 {
|
||||
l.Error(lua.LString("wrong number of arguments"), 1)
|
||||
return 0
|
||||
}
|
||||
msg := lua.LVAsString(l.Get(1))
|
||||
l.Push(lua.LString(sha1Hex(msg)))
|
||||
return 1
|
||||
},
|
||||
"replicate_commands": func(l *lua.LState) int {
|
||||
// always succeeds since 7.0.0
|
||||
l.Push(lua.LTrue)
|
||||
return 1
|
||||
},
|
||||
"set_repl": func(l *lua.LState) int {
|
||||
top := l.GetTop()
|
||||
if top != 1 {
|
||||
l.Error(lua.LString("wrong number of arguments"), 1)
|
||||
return 0
|
||||
}
|
||||
// ignored
|
||||
return 1
|
||||
},
|
||||
"setresp": func(l *lua.LState) int {
|
||||
level := l.CheckInt(1)
|
||||
toresp3 := false
|
||||
switch level {
|
||||
case 2:
|
||||
toresp3 = false
|
||||
case 3:
|
||||
toresp3 = true
|
||||
default:
|
||||
l.Error(lua.LString("RESP version must be 2 or 3"), 1)
|
||||
return 0
|
||||
}
|
||||
c.SwitchResp3 = &toresp3
|
||||
return 0
|
||||
},
|
||||
}, luaRedisConstants
|
||||
}
|
||||
|
||||
func luaToRedis(l *lua.LState, c *server.Peer, value lua.LValue) {
|
||||
if value == nil {
|
||||
c.WriteNull()
|
||||
return
|
||||
}
|
||||
|
||||
switch t := value.(type) {
|
||||
case *lua.LNilType:
|
||||
c.WriteNull()
|
||||
case lua.LBool:
|
||||
if lua.LVAsBool(value) {
|
||||
c.WriteInt(1)
|
||||
} else {
|
||||
c.WriteNull()
|
||||
}
|
||||
case lua.LNumber:
|
||||
c.WriteInt(int(lua.LVAsNumber(value)))
|
||||
case lua.LString:
|
||||
s := lua.LVAsString(value)
|
||||
c.WriteBulk(s)
|
||||
case *lua.LTable:
|
||||
// special case for tables with an 'err' or 'ok' field
|
||||
// note: according to the docs this only counts when 'err' or 'ok' is
|
||||
// the only field.
|
||||
if s := t.RawGetString("err"); s.Type() != lua.LTNil {
|
||||
c.WriteError(s.String())
|
||||
return
|
||||
}
|
||||
if s := t.RawGetString("ok"); s.Type() != lua.LTNil {
|
||||
c.WriteInline(s.String())
|
||||
return
|
||||
}
|
||||
|
||||
result := []lua.LValue{}
|
||||
for j := 1; true; j++ {
|
||||
val := l.GetTable(value, lua.LNumber(j))
|
||||
if val == nil {
|
||||
result = append(result, val)
|
||||
continue
|
||||
}
|
||||
|
||||
if val.Type() == lua.LTNil {
|
||||
break
|
||||
}
|
||||
|
||||
result = append(result, val)
|
||||
}
|
||||
|
||||
c.WriteLen(len(result))
|
||||
for _, r := range result {
|
||||
luaToRedis(l, c, r)
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("wat: %T", t))
|
||||
}
|
||||
}
|
||||
|
||||
func redisToLua(l *lua.LState, res []interface{}) *lua.LTable {
|
||||
rettb := l.NewTable()
|
||||
for _, e := range res {
|
||||
var v lua.LValue
|
||||
if e == nil {
|
||||
v = lua.LFalse
|
||||
} else {
|
||||
switch et := e.(type) {
|
||||
case int:
|
||||
v = lua.LNumber(et)
|
||||
case int64:
|
||||
v = lua.LNumber(et)
|
||||
case []uint8:
|
||||
v = lua.LString(string(et))
|
||||
case []interface{}:
|
||||
v = redisToLua(l, et)
|
||||
case string:
|
||||
v = lua.LString(et)
|
||||
default:
|
||||
// TODO: oops?
|
||||
v = lua.LString(e.(string))
|
||||
}
|
||||
}
|
||||
l.RawSet(rettb, lua.LNumber(rettb.Len()+1), v)
|
||||
}
|
||||
return rettb
|
||||
}
|
||||
|
||||
func luaStatusReply(msg string) *lua.LTable {
|
||||
tab := &lua.LTable{}
|
||||
tab.RawSetString("ok", lua.LString(msg))
|
||||
return tab
|
||||
}
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
This package is a mechanical translation of the reference C++ code for
|
||||
MetroHash, available at https://github.com/jandrewrogers/MetroHash
|
||||
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Damian Gryski
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+1
@@ -0,0 +1 @@
|
||||
This is a partial copy of github.com/dgryski/go-metro.
|
||||
+87
@@ -0,0 +1,87 @@
|
||||
package metro
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func Hash64(buffer []byte, seed uint64) uint64 {
|
||||
|
||||
const (
|
||||
k0 = 0xD6D018F5
|
||||
k1 = 0xA2AA033B
|
||||
k2 = 0x62992FC1
|
||||
k3 = 0x30BC5B29
|
||||
)
|
||||
|
||||
ptr := buffer
|
||||
|
||||
hash := (seed + k2) * k0
|
||||
|
||||
if len(ptr) >= 32 {
|
||||
v := [4]uint64{hash, hash, hash, hash}
|
||||
|
||||
for len(ptr) >= 32 {
|
||||
v[0] += binary.LittleEndian.Uint64(ptr[:8]) * k0
|
||||
v[0] = rotate_right(v[0], 29) + v[2]
|
||||
v[1] += binary.LittleEndian.Uint64(ptr[8:16]) * k1
|
||||
v[1] = rotate_right(v[1], 29) + v[3]
|
||||
v[2] += binary.LittleEndian.Uint64(ptr[16:24]) * k2
|
||||
v[2] = rotate_right(v[2], 29) + v[0]
|
||||
v[3] += binary.LittleEndian.Uint64(ptr[24:32]) * k3
|
||||
v[3] = rotate_right(v[3], 29) + v[1]
|
||||
ptr = ptr[32:]
|
||||
}
|
||||
|
||||
v[2] ^= rotate_right(((v[0]+v[3])*k0)+v[1], 37) * k1
|
||||
v[3] ^= rotate_right(((v[1]+v[2])*k1)+v[0], 37) * k0
|
||||
v[0] ^= rotate_right(((v[0]+v[2])*k0)+v[3], 37) * k1
|
||||
v[1] ^= rotate_right(((v[1]+v[3])*k1)+v[2], 37) * k0
|
||||
hash += v[0] ^ v[1]
|
||||
}
|
||||
|
||||
if len(ptr) >= 16 {
|
||||
v0 := hash + (binary.LittleEndian.Uint64(ptr[:8]) * k2)
|
||||
v0 = rotate_right(v0, 29) * k3
|
||||
v1 := hash + (binary.LittleEndian.Uint64(ptr[8:16]) * k2)
|
||||
v1 = rotate_right(v1, 29) * k3
|
||||
v0 ^= rotate_right(v0*k0, 21) + v1
|
||||
v1 ^= rotate_right(v1*k3, 21) + v0
|
||||
hash += v1
|
||||
ptr = ptr[16:]
|
||||
}
|
||||
|
||||
if len(ptr) >= 8 {
|
||||
hash += binary.LittleEndian.Uint64(ptr[:8]) * k3
|
||||
ptr = ptr[8:]
|
||||
hash ^= rotate_right(hash, 55) * k1
|
||||
}
|
||||
|
||||
if len(ptr) >= 4 {
|
||||
hash += uint64(binary.LittleEndian.Uint32(ptr[:4])) * k3
|
||||
hash ^= rotate_right(hash, 26) * k1
|
||||
ptr = ptr[4:]
|
||||
}
|
||||
|
||||
if len(ptr) >= 2 {
|
||||
hash += uint64(binary.LittleEndian.Uint16(ptr[:2])) * k3
|
||||
ptr = ptr[2:]
|
||||
hash ^= rotate_right(hash, 48) * k1
|
||||
}
|
||||
|
||||
if len(ptr) >= 1 {
|
||||
hash += uint64(ptr[0]) * k3
|
||||
hash ^= rotate_right(hash, 37) * k1
|
||||
}
|
||||
|
||||
hash ^= rotate_right(hash, 28)
|
||||
hash *= k0
|
||||
hash ^= rotate_right(hash, 29)
|
||||
|
||||
return hash
|
||||
}
|
||||
|
||||
func Hash64Str(buffer string, seed uint64) uint64 {
|
||||
return Hash64([]byte(buffer), seed)
|
||||
}
|
||||
|
||||
func rotate_right(v uint64, k uint) uint64 {
|
||||
return (v >> k) | (v << (64 - k))
|
||||
}
|
||||
+759
@@ -0,0 +1,759 @@
|
||||
// Package miniredis is a pure Go Redis test server, for use in Go unittests.
|
||||
// There are no dependencies on system binaries, and every server you start
|
||||
// will be empty.
|
||||
//
|
||||
// import "github.com/alicebob/miniredis/v2"
|
||||
//
|
||||
// Start a server with `s := miniredis.RunT(t)`, it'll be shutdown via a t.Cleanup().
|
||||
// Or do everything manual: `s, err := miniredis.Run(); defer s.Close()`
|
||||
//
|
||||
// Point your Redis client to `s.Addr()` or `s.Host(), s.Port()`.
|
||||
//
|
||||
// Set keys directly via s.Set(...) and similar commands, or use a Redis client.
|
||||
//
|
||||
// For direct use you can select a Redis database with either `s.Select(12);
|
||||
// s.Get("foo")` or `s.DB(12).Get("foo")`.
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/proto"
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
var DumpMaxLineLen = 60
|
||||
|
||||
type hashKey map[string]string
|
||||
type listKey []string
|
||||
type setKey map[string]struct{}
|
||||
|
||||
// RedisDB holds a single (numbered) Redis database.
|
||||
type RedisDB struct {
|
||||
master *Miniredis // pointer to the lock in Miniredis
|
||||
id int // db id
|
||||
keys map[string]string // Master map of keys with their type
|
||||
stringKeys map[string]string // GET/SET &c. keys
|
||||
hashKeys map[string]hashKey // MGET/MSET &c. keys
|
||||
listKeys map[string]listKey // LPUSH &c. keys
|
||||
setKeys map[string]setKey // SADD &c. keys
|
||||
hllKeys map[string]*hll // PFADD &c. keys
|
||||
sortedsetKeys map[string]sortedSet // ZADD &c. keys
|
||||
streamKeys map[string]*streamKey // XADD &c. keys
|
||||
ttl map[string]time.Duration // effective TTL values
|
||||
lru map[string]time.Time // last recently used ( read or written to )
|
||||
keyVersion map[string]uint // used to watch values
|
||||
}
|
||||
|
||||
// Miniredis is a Redis server implementation.
|
||||
type Miniredis struct {
|
||||
sync.Mutex
|
||||
srv *server.Server
|
||||
port int
|
||||
passwords map[string]string // username password
|
||||
dbs map[int]*RedisDB
|
||||
selectedDB int // DB id used in the direct Get(), Set() &c.
|
||||
scripts map[string]string // sha1 -> lua src
|
||||
signal *sync.Cond
|
||||
now time.Time // time.Now() if not set.
|
||||
subscribers map[*Subscriber]struct{}
|
||||
rand *rand.Rand
|
||||
Ctx context.Context
|
||||
CtxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type txCmd func(*server.Peer, *connCtx)
|
||||
|
||||
// database id + key combo
|
||||
type dbKey struct {
|
||||
db int
|
||||
key string
|
||||
}
|
||||
|
||||
// connCtx has all state for a single connection.
|
||||
// (this struct was named before context.Context existed)
|
||||
type connCtx struct {
|
||||
selectedDB int // selected DB
|
||||
authenticated bool // auth enabled and a valid AUTH seen
|
||||
transaction []txCmd // transaction callbacks. Or nil.
|
||||
dirtyTransaction bool // any error during QUEUEing
|
||||
watch map[dbKey]uint // WATCHed keys
|
||||
subscriber *Subscriber // client is in PUBSUB mode if not nil
|
||||
nested bool // this is called via Lua
|
||||
nestedSHA string // set to the SHA of the nesting function
|
||||
}
|
||||
|
||||
// NewMiniRedis makes a new, non-started, Miniredis object.
|
||||
func NewMiniRedis() *Miniredis {
|
||||
m := Miniredis{
|
||||
dbs: map[int]*RedisDB{},
|
||||
scripts: map[string]string{},
|
||||
subscribers: map[*Subscriber]struct{}{},
|
||||
}
|
||||
m.Ctx, m.CtxCancel = context.WithCancel(context.Background())
|
||||
m.signal = sync.NewCond(&m)
|
||||
return &m
|
||||
}
|
||||
|
||||
func newRedisDB(id int, m *Miniredis) RedisDB {
|
||||
return RedisDB{
|
||||
id: id,
|
||||
master: m,
|
||||
keys: map[string]string{},
|
||||
lru: map[string]time.Time{},
|
||||
stringKeys: map[string]string{},
|
||||
hashKeys: map[string]hashKey{},
|
||||
listKeys: map[string]listKey{},
|
||||
setKeys: map[string]setKey{},
|
||||
hllKeys: map[string]*hll{},
|
||||
sortedsetKeys: map[string]sortedSet{},
|
||||
streamKeys: map[string]*streamKey{},
|
||||
ttl: map[string]time.Duration{},
|
||||
keyVersion: map[string]uint{},
|
||||
}
|
||||
}
|
||||
|
||||
// Run creates and Start()s a Miniredis.
|
||||
func Run() (*Miniredis, error) {
|
||||
m := NewMiniRedis()
|
||||
return m, m.Start()
|
||||
}
|
||||
|
||||
// Run creates and Start()s a Miniredis, TLS version.
|
||||
func RunTLS(cfg *tls.Config) (*Miniredis, error) {
|
||||
m := NewMiniRedis()
|
||||
return m, m.StartTLS(cfg)
|
||||
}
|
||||
|
||||
// Tester is a minimal version of a testing.T
|
||||
type Tester interface {
|
||||
Fatalf(string, ...interface{})
|
||||
Cleanup(func())
|
||||
Logf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// RunT start a new miniredis, pass it a testing.T. It also registers the cleanup after your test is done.
|
||||
func RunT(t Tester) *Miniredis {
|
||||
m := NewMiniRedis()
|
||||
if err := m.Start(); err != nil {
|
||||
t.Fatalf("could not start miniredis: %s", err)
|
||||
// not reached
|
||||
}
|
||||
t.Cleanup(m.Close)
|
||||
return m
|
||||
}
|
||||
|
||||
func runWithClient(t Tester) (*Miniredis, *proto.Client) {
|
||||
m := RunT(t)
|
||||
|
||||
c, err := proto.Dial(m.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("could not connect to miniredis: %s", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Logf("error closing connection to miniredis: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
return m, c
|
||||
}
|
||||
|
||||
// Start starts a server. It listens on a random port on localhost. See also
|
||||
// Addr().
|
||||
func (m *Miniredis) Start() error {
|
||||
s, err := server.NewServer(fmt.Sprintf("127.0.0.1:%d", m.port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.start(s)
|
||||
}
|
||||
|
||||
// Start starts a server, TLS version.
|
||||
func (m *Miniredis) StartTLS(cfg *tls.Config) error {
|
||||
s, err := server.NewServerTLS(fmt.Sprintf("127.0.0.1:%d", m.port), cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.start(s)
|
||||
}
|
||||
|
||||
// StartAddr runs miniredis with a given addr. Examples: "127.0.0.1:6379",
|
||||
// ":6379", or "127.0.0.1:0"
|
||||
func (m *Miniredis) StartAddr(addr string) error {
|
||||
s, err := server.NewServer(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.start(s)
|
||||
}
|
||||
|
||||
// StartAddrTLS runs miniredis with a given addr, TLS version.
|
||||
func (m *Miniredis) StartAddrTLS(addr string, cfg *tls.Config) error {
|
||||
s, err := server.NewServerTLS(addr, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.start(s)
|
||||
}
|
||||
|
||||
func (m *Miniredis) start(s *server.Server) error {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.srv = s
|
||||
m.port = s.Addr().Port
|
||||
|
||||
commandsConnection(m)
|
||||
commandsGeneric(m)
|
||||
commandsServer(m)
|
||||
commandsString(m)
|
||||
commandsHash(m)
|
||||
commandsList(m)
|
||||
commandsPubsub(m)
|
||||
commandsSet(m)
|
||||
commandsSortedSet(m)
|
||||
commandsStream(m)
|
||||
commandsTransaction(m)
|
||||
commandsScripting(m)
|
||||
commandsGeo(m)
|
||||
commandsCluster(m)
|
||||
commandsHll(m)
|
||||
commandsClient(m)
|
||||
commandsObject(m)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restart restarts a Close()d server on the same port. Values will be
|
||||
// preserved.
|
||||
func (m *Miniredis) Restart() error {
|
||||
return m.Start()
|
||||
}
|
||||
|
||||
// Close shuts down a Miniredis.
|
||||
func (m *Miniredis) Close() {
|
||||
m.Lock()
|
||||
|
||||
if m.srv == nil {
|
||||
m.Unlock()
|
||||
return
|
||||
}
|
||||
srv := m.srv
|
||||
m.srv = nil
|
||||
m.CtxCancel()
|
||||
m.Unlock()
|
||||
|
||||
// the OnDisconnect callbacks can lock m, so run Close() outside the lock.
|
||||
srv.Close()
|
||||
|
||||
}
|
||||
|
||||
// RequireAuth makes every connection need to AUTH first. This is the old 'AUTH [password] command.
|
||||
// Remove it by setting an empty string.
|
||||
func (m *Miniredis) RequireAuth(pw string) {
|
||||
m.RequireUserAuth("default", pw)
|
||||
}
|
||||
|
||||
// Add a username/password, for use with 'AUTH [username] [password]'.
|
||||
// There are currently no access controls for commands implemented.
|
||||
// Disable access for the user with an empty password.
|
||||
func (m *Miniredis) RequireUserAuth(username, pw string) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.passwords == nil {
|
||||
m.passwords = map[string]string{}
|
||||
}
|
||||
if pw == "" {
|
||||
delete(m.passwords, username)
|
||||
return
|
||||
}
|
||||
m.passwords[username] = pw
|
||||
}
|
||||
|
||||
// DB returns a DB by ID.
|
||||
func (m *Miniredis) DB(i int) *RedisDB {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.db(i)
|
||||
}
|
||||
|
||||
// get DB. No locks!
|
||||
func (m *Miniredis) db(i int) *RedisDB {
|
||||
if db, ok := m.dbs[i]; ok {
|
||||
return db
|
||||
}
|
||||
db := newRedisDB(i, m) // main miniredis has our mutex.
|
||||
m.dbs[i] = &db
|
||||
return &db
|
||||
}
|
||||
|
||||
// SwapDB swaps DBs by IDs.
|
||||
func (m *Miniredis) SwapDB(i, j int) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.swapDB(i, j)
|
||||
}
|
||||
|
||||
// swap DB. No locks!
|
||||
func (m *Miniredis) swapDB(i, j int) bool {
|
||||
db1 := m.db(i)
|
||||
db2 := m.db(j)
|
||||
|
||||
db1.id = j
|
||||
db2.id = i
|
||||
|
||||
m.dbs[i] = db2
|
||||
m.dbs[j] = db1
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Addr returns '127.0.0.1:12345'. Can be given to a Dial(). See also Host()
|
||||
// and Port(), which return the same things.
|
||||
func (m *Miniredis) Addr() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.srv.Addr().String()
|
||||
}
|
||||
|
||||
// Host returns the host part of Addr().
|
||||
func (m *Miniredis) Host() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.srv.Addr().IP.String()
|
||||
}
|
||||
|
||||
// Port returns the (random) port part of Addr().
|
||||
func (m *Miniredis) Port() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return strconv.Itoa(m.srv.Addr().Port)
|
||||
}
|
||||
|
||||
// CommandCount returns the number of processed commands.
|
||||
func (m *Miniredis) CommandCount() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return int(m.srv.TotalCommands())
|
||||
}
|
||||
|
||||
// CurrentConnectionCount returns the number of currently connected clients.
|
||||
func (m *Miniredis) CurrentConnectionCount() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.srv.ClientsLen()
|
||||
}
|
||||
|
||||
// TotalConnectionCount returns the number of client connections since server start.
|
||||
func (m *Miniredis) TotalConnectionCount() int {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return int(m.srv.TotalConnections())
|
||||
}
|
||||
|
||||
// FastForward decreases all TTLs by the given duration. All TTLs <= 0 will be
|
||||
// expired.
|
||||
func (m *Miniredis) FastForward(duration time.Duration) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, db := range m.dbs {
|
||||
db.fastForward(duration)
|
||||
}
|
||||
}
|
||||
|
||||
// Server returns the underlying server to allow custom commands to be implemented
|
||||
func (m *Miniredis) Server() *server.Server {
|
||||
return m.srv
|
||||
}
|
||||
|
||||
// Dump returns a text version of the selected DB, usable for debugging.
|
||||
//
|
||||
// Dump limits the maximum length of each key:value to "DumpMaxLineLen" characters.
|
||||
// To increase that, call something like:
|
||||
//
|
||||
// miniredis.DumpMaxLineLen = 1024
|
||||
// mr, _ = miniredis.Run()
|
||||
// mr.Dump()
|
||||
func (m *Miniredis) Dump() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
var (
|
||||
maxLen = DumpMaxLineLen
|
||||
indent = " "
|
||||
db = m.db(m.selectedDB)
|
||||
r = ""
|
||||
v = func(s string) string {
|
||||
suffix := ""
|
||||
if len(s) > maxLen {
|
||||
suffix = fmt.Sprintf("...(%d)", len(s))
|
||||
s = s[:maxLen-len(suffix)]
|
||||
}
|
||||
return fmt.Sprintf("%q%s", s, suffix)
|
||||
}
|
||||
)
|
||||
|
||||
for _, k := range db.allKeys() {
|
||||
r += fmt.Sprintf("- %s\n", k)
|
||||
t := db.t(k)
|
||||
switch t {
|
||||
case keyTypeString:
|
||||
r += fmt.Sprintf("%s%s\n", indent, v(db.stringKeys[k]))
|
||||
case keyTypeHash:
|
||||
for _, hk := range db.hashFields(k) {
|
||||
r += fmt.Sprintf("%s%s: %s\n", indent, hk, v(db.hashGet(k, hk)))
|
||||
}
|
||||
case keyTypeList:
|
||||
for _, lk := range db.listKeys[k] {
|
||||
r += fmt.Sprintf("%s%s\n", indent, v(lk))
|
||||
}
|
||||
case keyTypeSet:
|
||||
for _, mk := range db.setMembers(k) {
|
||||
r += fmt.Sprintf("%s%s\n", indent, v(mk))
|
||||
}
|
||||
case keyTypeSortedSet:
|
||||
for _, el := range db.ssetElements(k) {
|
||||
r += fmt.Sprintf("%s%f: %s\n", indent, el.score, v(el.member))
|
||||
}
|
||||
case keyTypeStream:
|
||||
for _, entry := range db.streamKeys[k].entries {
|
||||
r += fmt.Sprintf("%s%s\n", indent, entry.ID)
|
||||
ev := entry.Values
|
||||
for i := 0; i < len(ev)/2; i++ {
|
||||
r += fmt.Sprintf("%s%s%s: %s\n", indent, indent, v(ev[2*i]), v(ev[2*i+1]))
|
||||
}
|
||||
}
|
||||
case keyTypeHll:
|
||||
for _, entry := range db.hllKeys {
|
||||
r += fmt.Sprintf("%s%s\n", indent, v(string(entry.Bytes())))
|
||||
}
|
||||
default:
|
||||
r += fmt.Sprintf("%s(a %s, fixme!)\n", indent, t)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// SetTime sets the time against which EXPIREAT values are compared, and the
|
||||
// time used in stream entry IDs. Will use time.Now() if this is not set.
|
||||
func (m *Miniredis) SetTime(t time.Time) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.now = t
|
||||
}
|
||||
|
||||
// make every command return this message. For example:
|
||||
//
|
||||
// LOADING Redis is loading the dataset in memory
|
||||
// MASTERDOWN Link with MASTER is down and replica-serve-stale-data is set to 'no'.
|
||||
//
|
||||
// Clear it with an empty string. Don't add newlines.
|
||||
func (m *Miniredis) SetError(msg string) {
|
||||
cb := server.Hook(nil)
|
||||
if msg != "" {
|
||||
cb = func(c *server.Peer, cmd string, args ...string) bool {
|
||||
c.WriteError(msg)
|
||||
return true
|
||||
}
|
||||
}
|
||||
m.srv.SetPreHook(cb)
|
||||
}
|
||||
|
||||
// isValidCMD returns true if command is valid and can be executed.
|
||||
func (m *Miniredis) isValidCMD(c *server.Peer, cmd string) bool {
|
||||
if !m.handleAuth(c) {
|
||||
return false
|
||||
}
|
||||
if m.checkPubsub(c, cmd) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleAuth returns false if connection has no access. It sends the reply.
|
||||
func (m *Miniredis) handleAuth(c *server.Peer) bool {
|
||||
if getCtx(c).nested {
|
||||
return true
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if len(m.passwords) == 0 {
|
||||
return true
|
||||
}
|
||||
if !getCtx(c).authenticated {
|
||||
c.WriteError("NOAUTH Authentication required.")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// handlePubsub sends an error to the user if the connection is in PUBSUB mode.
|
||||
// It'll return true if it did.
|
||||
func (m *Miniredis) checkPubsub(c *server.Peer, cmd string) bool {
|
||||
if getCtx(c).nested {
|
||||
return false
|
||||
}
|
||||
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
ctx := getCtx(c)
|
||||
if ctx.subscriber == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix := "ERR "
|
||||
if strings.ToLower(cmd) == "exec" {
|
||||
prefix = "EXECABORT Transaction discarded because of: "
|
||||
}
|
||||
c.WriteError(fmt.Sprintf(
|
||||
"%sCan't execute '%s': only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT are allowed in this context",
|
||||
prefix,
|
||||
strings.ToLower(cmd),
|
||||
))
|
||||
return true
|
||||
}
|
||||
|
||||
func getCtx(c *server.Peer) *connCtx {
|
||||
if c.Ctx == nil {
|
||||
c.Ctx = &connCtx{}
|
||||
}
|
||||
return c.Ctx.(*connCtx)
|
||||
}
|
||||
|
||||
func startTx(ctx *connCtx) {
|
||||
ctx.transaction = []txCmd{}
|
||||
ctx.dirtyTransaction = false
|
||||
}
|
||||
|
||||
func stopTx(ctx *connCtx) {
|
||||
ctx.transaction = nil
|
||||
unwatch(ctx)
|
||||
}
|
||||
|
||||
func inTx(ctx *connCtx) bool {
|
||||
return ctx.transaction != nil
|
||||
}
|
||||
|
||||
func addTxCmd(ctx *connCtx, cb txCmd) {
|
||||
ctx.transaction = append(ctx.transaction, cb)
|
||||
}
|
||||
|
||||
func watch(db *RedisDB, ctx *connCtx, key string) {
|
||||
if ctx.watch == nil {
|
||||
ctx.watch = map[dbKey]uint{}
|
||||
}
|
||||
ctx.watch[dbKey{db: db.id, key: key}] = db.keyVersion[key] // Can be 0.
|
||||
}
|
||||
|
||||
func unwatch(ctx *connCtx) {
|
||||
ctx.watch = nil
|
||||
}
|
||||
|
||||
// setDirty can be called even when not in an tx. Is an no-op then.
|
||||
func setDirty(c *server.Peer) {
|
||||
if c.Ctx == nil {
|
||||
// No transaction. Not relevant.
|
||||
return
|
||||
}
|
||||
getCtx(c).dirtyTransaction = true
|
||||
}
|
||||
|
||||
func (m *Miniredis) addSubscriber(s *Subscriber) {
|
||||
m.subscribers[s] = struct{}{}
|
||||
}
|
||||
|
||||
// closes and remove the subscriber.
|
||||
func (m *Miniredis) removeSubscriber(s *Subscriber) {
|
||||
_, ok := m.subscribers[s]
|
||||
delete(m.subscribers, s)
|
||||
if ok {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Miniredis) publish(c, msg string) int {
|
||||
n := 0
|
||||
for s := range m.subscribers {
|
||||
n += s.Publish(c, msg)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// enter 'subscribed state', or return the existing one.
|
||||
func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber {
|
||||
ctx := getCtx(c)
|
||||
sub := ctx.subscriber
|
||||
if sub != nil {
|
||||
return sub
|
||||
}
|
||||
|
||||
sub = newSubscriber()
|
||||
m.addSubscriber(sub)
|
||||
|
||||
c.OnDisconnect(func() {
|
||||
m.Lock()
|
||||
m.removeSubscriber(sub)
|
||||
m.Unlock()
|
||||
})
|
||||
|
||||
ctx.subscriber = sub
|
||||
|
||||
go monitorPublish(c, sub.publish)
|
||||
go monitorPpublish(c, sub.ppublish)
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
// whenever the p?sub count drops to 0 subscribed state should be stopped, and
|
||||
// all redis commands are allowed again.
|
||||
func endSubscriber(m *Miniredis, c *server.Peer) {
|
||||
ctx := getCtx(c)
|
||||
if sub := ctx.subscriber; sub != nil {
|
||||
m.removeSubscriber(sub) // will Close() the sub
|
||||
}
|
||||
ctx.subscriber = nil
|
||||
}
|
||||
|
||||
// Start a new pubsub subscriber. It can (un) subscribe to channels and
|
||||
// patterns, and has a channel to get published messages. Close it with
|
||||
// Close().
|
||||
// Does not close itself when there are no subscriptions left.
|
||||
func (m *Miniredis) NewSubscriber() *Subscriber {
|
||||
sub := newSubscriber()
|
||||
|
||||
m.Lock()
|
||||
m.addSubscriber(sub)
|
||||
m.Unlock()
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
func (m *Miniredis) allSubscribers() []*Subscriber {
|
||||
var subs []*Subscriber
|
||||
for s := range m.subscribers {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
return subs
|
||||
}
|
||||
|
||||
func (m *Miniredis) Seed(seed int) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// m.rand is not safe for concurrent use.
|
||||
m.rand = rand.New(rand.NewSource(int64(seed)))
|
||||
}
|
||||
|
||||
func (m *Miniredis) randIntn(n int) int {
|
||||
if m.rand == nil {
|
||||
return rand.Intn(n)
|
||||
}
|
||||
return m.rand.Intn(n)
|
||||
}
|
||||
|
||||
// shuffle shuffles a list of strings. Kinda.
|
||||
func (m *Miniredis) shuffle(l []string) {
|
||||
for range l {
|
||||
i := m.randIntn(len(l))
|
||||
j := m.randIntn(len(l))
|
||||
l[i], l[j] = l[j], l[i]
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Miniredis) effectiveNow() time.Time {
|
||||
if !m.now.IsZero() {
|
||||
return m.now
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
// convert a unixtimestamp to a duration, to use an absolute time as TTL.
|
||||
// d can be either time.Second or time.Millisecond.
|
||||
func (m *Miniredis) at(i int, d time.Duration) time.Duration {
|
||||
var ts time.Time
|
||||
switch d {
|
||||
case time.Millisecond:
|
||||
ts = time.Unix(int64(i/1000), 1000000*int64(i%1000))
|
||||
case time.Second:
|
||||
ts = time.Unix(int64(i), 0)
|
||||
default:
|
||||
panic("invalid time unit (d). Fixme!")
|
||||
}
|
||||
now := m.effectiveNow()
|
||||
return ts.Sub(now)
|
||||
}
|
||||
|
||||
// copy does not mind if dst already exists.
|
||||
func (m *Miniredis) copy(
|
||||
srcDB *RedisDB, src string,
|
||||
destDB *RedisDB, dst string,
|
||||
) error {
|
||||
if !srcDB.exists(src) {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
switch srcDB.t(src) {
|
||||
case keyTypeString:
|
||||
destDB.stringKeys[dst] = srcDB.stringKeys[src]
|
||||
case keyTypeHash:
|
||||
destDB.hashKeys[dst] = copyHashKey(srcDB.hashKeys[src])
|
||||
case keyTypeList:
|
||||
destDB.listKeys[dst] = copyListKey(srcDB.listKeys[src])
|
||||
case keyTypeSet:
|
||||
destDB.setKeys[dst] = copySetKey(srcDB.setKeys[src])
|
||||
case keyTypeSortedSet:
|
||||
destDB.sortedsetKeys[dst] = copySortedSet(srcDB.sortedsetKeys[src])
|
||||
case keyTypeStream:
|
||||
destDB.streamKeys[dst] = srcDB.streamKeys[src].copy()
|
||||
case keyTypeHll:
|
||||
destDB.hllKeys[dst] = srcDB.hllKeys[src].copy()
|
||||
default:
|
||||
panic("missing case")
|
||||
}
|
||||
destDB.keys[dst] = srcDB.keys[src]
|
||||
destDB.incr(dst)
|
||||
if v, ok := srcDB.ttl[src]; ok {
|
||||
destDB.ttl[dst] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyHashKey(orig hashKey) hashKey {
|
||||
cpy := hashKey{}
|
||||
for k, v := range orig {
|
||||
cpy[k] = v
|
||||
}
|
||||
return cpy
|
||||
}
|
||||
|
||||
func copyListKey(orig listKey) listKey {
|
||||
cpy := make(listKey, len(orig))
|
||||
copy(cpy, orig)
|
||||
return cpy
|
||||
}
|
||||
|
||||
func copySetKey(orig setKey) setKey {
|
||||
cpy := setKey{}
|
||||
for k, v := range orig {
|
||||
cpy[k] = v
|
||||
}
|
||||
return cpy
|
||||
}
|
||||
|
||||
func copySortedSet(orig sortedSet) sortedSet {
|
||||
cpy := sortedSet{}
|
||||
for k, v := range orig {
|
||||
cpy[k] = v
|
||||
}
|
||||
return cpy
|
||||
}
|
||||
+60
@@ -0,0 +1,60 @@
|
||||
package miniredis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2/server"
|
||||
)
|
||||
|
||||
// optInt parses an int option in a command.
|
||||
// Writes "invalid integer" error to c if it's not a valid integer. Returns
|
||||
// whether or not things were okay.
|
||||
func optInt(c *server.Peer, src string, dest *int) bool {
|
||||
return optIntErr(c, src, dest, msgInvalidInt)
|
||||
}
|
||||
|
||||
func optIntErr(c *server.Peer, src string, dest *int, errMsg string) bool {
|
||||
n, err := strconv.Atoi(src)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(errMsg)
|
||||
return false
|
||||
}
|
||||
*dest = n
|
||||
return true
|
||||
}
|
||||
|
||||
// optIntSimple sets dest or returns an error
|
||||
func optIntSimple(src string, dest *int) error {
|
||||
n, err := strconv.Atoi(src)
|
||||
if err != nil {
|
||||
return errors.New(msgInvalidInt)
|
||||
}
|
||||
*dest = n
|
||||
return nil
|
||||
}
|
||||
|
||||
func optDuration(c *server.Peer, src string, dest *time.Duration) bool {
|
||||
n, err := strconv.ParseFloat(src, 64)
|
||||
if err != nil {
|
||||
setDirty(c)
|
||||
c.WriteError(msgInvalidTimeout)
|
||||
return false
|
||||
}
|
||||
if n < 0 {
|
||||
setDirty(c)
|
||||
c.WriteError(msgTimeoutNegative)
|
||||
return false
|
||||
}
|
||||
if math.IsInf(n, 0) {
|
||||
setDirty(c)
|
||||
c.WriteError(msgTimeoutIsOutOfRange)
|
||||
return false
|
||||
}
|
||||
|
||||
*dest = time.Duration(n*1_000_000) * time.Microsecond
|
||||
return true
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user