Size computation for allocation may overflow (#99)

* Size computation for allocation may overflow

Performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
This commit is contained in:
2025-12-08 11:22:28 +00:00
committed by GitHub
parent 56051779ee
commit a750c4f5b9
93 changed files with 10500 additions and 443 deletions
+1 -1
View File
@@ -18,6 +18,6 @@ jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24"
go-version: "1.24.11"
coverage-threshold: 70
secrets: inherit
+1 -1
View File
@@ -17,5 +17,5 @@ jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24"
go-version: "1.24.11"
secrets: inherit
+1
View File
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
}
if mm.logger != nil {
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
}
+1
View File
@@ -105,6 +105,7 @@ func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
}
// Read the file with validated path
// #nosec G304 -- path is validated via filepath.Abs above
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
+1
View File
@@ -252,6 +252,7 @@ func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
}
// Read the file with validated path
// #nosec G304 -- path is validated via filepath.Abs above
data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
+1
View File
@@ -348,6 +348,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
+1
View File
@@ -538,6 +538,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
delay = float64(re.config.MaxDelay)
}
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.EnableJitter {
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
delay += jitter
+1 -1
View File
@@ -6,7 +6,7 @@ require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/redis/go-redis/v9 v9.14.0
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
+2 -2
View File
@@ -20,8 +20,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
+2 -2
View File
@@ -76,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
// Test connectivity
if err := backend.Ping(context.Background()); err != nil {
pool.Close()
_ = pool.Close()
return nil, fmt.Errorf("failed to ping Redis: %w", err)
}
@@ -263,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
if err != nil {
continue
}
conn.Do("DEL", key) // Best effort, ignore errors
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
}
return nil
+23 -22
View File
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// Reuse existing connection - validate if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
// Connection is stale, close it and try again
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
default:
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
if err != nil {
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
case conn = <-p.connections:
// Validate connection if health check enabled
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
continue
}
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
p.activeConns.Add(-1)
if p.closed.Load() || conn.closed.Load() {
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
return
}
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
// Successfully returned to pool
default:
// Pool full, close connection
conn.Close()
_ = conn.Close()
p.totalConns.Add(-1)
}
}
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
// Close all pooled connections
for conn := range p.connections {
conn.Close()
_ = conn.Close()
}
return nil
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Authenticate if password is provided
if p.config.Password != "" {
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
redisConn.Close()
_ = redisConn.Close()
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
// Select database
if p.config.DB != 0 {
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
redisConn.Close()
_ = redisConn.Close()
return nil, fmt.Errorf("failed to select database: %w", err)
}
}
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
c.mu.Lock()
defer c.mu.Unlock()
// Build command arguments
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
// Limit to a safe value that prevents integer overflow in allocation size calculation
// (capacity * sizeof(string) must fit in int/size_t)
argsLen := len(args)
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
if argsLen < 0 || argsLen > maxSafeArgs {
return nil, errors.New("too many arguments")
// Validate argument count to prevent integer overflow in slice operations
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
const maxSafeArgs = (1 << 20) - 1
if len(args) > maxSafeArgs {
return nil, errors.New("too many arguments: exceeds maximum safe count")
}
// Build command arguments
// Validate total argument size to prevent memory exhaustion
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
totalBytes := len(command)
for _, s := range args {
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
return nil, errors.New("total argument size exceeds maximum allowed")
}
}
cmdArgs := make([]string, 0, argsLen+1)
cmdArgs = append(cmdArgs, command)
cmdArgs = append(cmdArgs, args...)
// Build command slice: prepend command to args
// Using append avoids arithmetic on potentially large len(args)
cmdArgs := append([]string{command}, args...)
// Set write timeout
if c.writeTimeout > 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// Write command (using pooled writer for memory efficiency)
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
// Set read timeout
if c.readTimeout > 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
// Read response (using pooled reader for memory efficiency)
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
// Set a read deadline for the ping
if conn.conn != nil {
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
}
_, err := conn.Do("PING")
+3
View File
@@ -158,6 +158,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
case StateHalfOpen:
// Allow limited requests in half-open state
current := cb.halfOpenRequests.Add(1)
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
return current <= int32(cb.config.HalfOpenMaxRequests)
default:
@@ -181,6 +182,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
case StateHalfOpen:
// If we've had enough successful requests, close the circuit
successfulRequests := cb.halfOpenRequests.Load()
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
cb.setState(StateClosed)
cb.consecutiveFailures.Store(0)
@@ -203,6 +205,7 @@ func (cb *CircuitBreaker) RecordFailure() {
switch state {
case StateClosed:
// Check if we should open the circuit
// #nosec G115 -- MaxFailures is a small config value that fits in int32
if failures >= int32(cb.config.MaxFailures) {
cb.openCircuit()
} else if cb.config.FailureThreshold > 0 {
+2
View File
@@ -217,6 +217,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
newStatus := currentStatus
// Check if we should become healthy
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
if successes >= int32(hc.config.HealthyThreshold) {
if latency > hc.config.DegradedThreshold {
newStatus = HealthDegraded
@@ -241,6 +242,7 @@ func (hc *HealthChecker) recordFailure() {
hc.timeMu.Unlock()
// Check if we should become unhealthy
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
if failures >= int32(hc.config.UnhealthyThreshold) {
hc.setStatus(HealthUnhealthy)
}
+1
View File
@@ -150,6 +150,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
// recordResult records the result of an operation for health tracking
func (h *HealthCheckBackend) recordResult(success bool) {
// #nosec G115 -- threshold config values are small integers that fit in int32
if success {
fails := h.consecutiveFails.Swap(0)
oks := h.consecutiveOK.Add(1)
+2 -1
View File
@@ -304,7 +304,8 @@ func (f *Factory) createSecureTLSConfig() *tls.Config {
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
InsecureSkipVerify: false, // SECURITY: Always verify certificates
InsecureSkipVerify: false, // SECURITY: Always verify certificates
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it to false is safe
PreferServerCipherSuites: false, // Let client choose best cipher
}
}
+2
View File
@@ -144,12 +144,14 @@ func getOrCreateLogFile(filename string) io.Writer {
}
// Ensure log directory exists
// #nosec G301 -- log directory needs to be readable by monitoring tools
if err := os.MkdirAll(logDir, 0755); err != nil {
// Fall back to stderr if we can't create the directory
return os.Stderr
}
filepath := logDir + "/" + filename
// #nosec G302 G304 -- log files need to be readable; path is from trusted env var
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
// Fall back to stderr if we can't open the file
+2
View File
@@ -107,6 +107,7 @@ const (
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
// Bearer token pattern (Authorization header)
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
// Client ID pattern (alphanumeric with common separators)
@@ -119,6 +120,7 @@ const (
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
// CSRF token pattern (base64url)
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
// Nonce pattern (base64url)
+3 -1
View File
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
PreferServerCipherSuites: true,
InsecureSkipVerify: config.InsecureSkipVerify,
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
InsecureSkipVerify: config.InsecureSkipVerify,
}
return &http.Transport{
+3
View File
@@ -148,6 +148,7 @@ func (cb *CircuitBreaker) allowRequest() bool {
// allowHalfOpenRequest checks if a request is allowed in half-open state
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
// #nosec G115 -- MaxRequests is a small config value that fits in int32
if current <= int32(cb.config.MaxRequests) {
return true
}
@@ -164,6 +165,7 @@ func (cb *CircuitBreaker) recordFailure() {
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
cb.transitionToOpen()
} else if state == CircuitBreakerHalfOpen {
@@ -180,6 +182,7 @@ func (cb *CircuitBreaker) recordSuccess() {
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
cb.transitionToClosed()
}
+1
View File
@@ -191,6 +191,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
}
// Add jitter
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
if re.config.RandomizationFactor > 0 {
jitter := delay * re.config.RandomizationFactor
minDelay := delay - jitter
+4 -1
View File
@@ -169,7 +169,10 @@ func (jwk *JWK) ToRSAPublicKey() (*rsa.PublicKey, error) {
// Pad to 8 bytes for uint64
paddedE := make([]byte, 8)
copy(paddedE[8-len(eBytes):], eBytes)
e = int(binary.BigEndian.Uint64(paddedE))
eUint64 := binary.BigEndian.Uint64(paddedE)
// RSA exponents are typically small (65537 is common), so overflow is not a concern
// #nosec G115 -- RSA public exponents are small values that fit in int
e = int(eUint64)
} else {
return nil, fmt.Errorf("exponent too large")
}
+8 -5
View File
@@ -120,8 +120,9 @@ func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryM
alertThresholds: thresholds,
baselineHeap: memStats.HeapAlloc,
baselineGoroutines: runtime.NumGoroutine(),
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
lastGCCount: memStats.NumGC,
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
lastGCCount: memStats.NumGC,
}
}
@@ -158,9 +159,10 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
StackSysBytes: memStats.StackSys,
GCSysBytes: memStats.GCSys,
NumGoroutines: runtime.NumGoroutine(),
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
GCFrequency: gcFrequency,
Timestamp: now,
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
GCFrequency: gcFrequency,
Timestamp: now,
}
// Get application-specific stats
@@ -386,6 +388,7 @@ func (mm *MemoryMonitor) TriggerGC() {
after := mm.GetCurrentStats()
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
freedMB := float64(freedBytes) / (1024 * 1024)
+1
View File
@@ -63,6 +63,7 @@ func generateSecureRandomString(length int) (string, error) {
}
// Cookie names and configuration constants used for session management
// #nosec G101 -- These are cookie names, not hardcoded credentials
const (
mainCookieName = "_oidc_raczylo_m"
accessTokenCookie = "_oidc_raczylo_a"
+3 -2
View File
@@ -51,7 +51,8 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
}
return &ShardedCache{
shards: shards,
shards: shards,
// #nosec G115 -- numShards is validated to be positive and small (typically 32-256)
numShards: uint32(numShards),
maxPerShard: maxPerShard,
}
@@ -61,7 +62,7 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
// FNV-1a is fast and provides good distribution.
func (c *ShardedCache) getShard(key string) *cacheShard {
h := fnv.New32a()
h.Write([]byte(key))
_, _ = h.Write([]byte(key)) // hash.Hash.Write never returns an error
return c.shards[h.Sum32()%c.numShards]
}
+1
View File
@@ -734,6 +734,7 @@ func (r *TestSuiteRunner) RunMemoryLeakTests(t *testing.T, tests []MemoryLeakTes
}
// Check memory growth
// #nosec G115 -- memory stats are within int64 range for practical purposes
memoryGrowthBytes := int64(finalMem.Alloc) - int64(initialMem.Alloc)
memoryGrowthMB := float64(memoryGrowthBytes) / (1024 * 1024)
+4
View File
@@ -9,3 +9,7 @@ coverage.txt
**/coverage.txt
.vscode
tmp/*
*.test
# maintenanceNotifications upgrade documentation (temporary)
maintenanceNotifications/docs/
+2 -2
View File
@@ -1,8 +1,8 @@
GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
REDIS_VERSION ?= 8.2
REDIS_VERSION ?= 8.4
RE_CLUSTER ?= false
RCE_DOCKER ?= true
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.2.1-pre
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.4.0
docker.start:
export RE_CLUSTER=$(RE_CLUSTER) && \
+149 -14
View File
@@ -2,7 +2,7 @@
[![build workflow](https://github.com/redis/go-redis/actions/workflows/build.yml/badge.svg)](https://github.com/redis/go-redis/actions)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/redis/go-redis/v9)](https://pkg.go.dev/github.com/redis/go-redis/v9?tab=doc)
[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/)
[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.io/docs/latest/develop/clients/go/)
[![Go Report Card](https://goreportcard.com/badge/github.com/redis/go-redis/v9)](https://goreportcard.com/report/github.com/redis/go-redis/v9)
[![codecov](https://codecov.io/github/redis/go-redis/graph/badge.svg?token=tsrCZKuSSw)](https://codecov.io/github/redis/go-redis)
@@ -17,15 +17,15 @@
## Supported versions
In `go-redis` we are aiming to support the last three releases of Redis. Currently, this means we do support:
- [Redis 7.2](https://raw.githubusercontent.com/redis/redis/7.2/00-RELEASENOTES) - using Redis Stack 7.2 for modules support
- [Redis 7.4](https://raw.githubusercontent.com/redis/redis/7.4/00-RELEASENOTES) - using Redis Stack 7.4 for modules support
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0 where modules are included
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2 where modules are included
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2
- [Redis 8.4](https://raw.githubusercontent.com/redis/redis/8.4/00-RELEASENOTES) - using Redis CE 8.4
Although the `go.mod` states it requires at minimum `go 1.18`, our CI is configured to run the tests against all three
versions of Redis and latest two versions of Go ([1.23](https://go.dev/doc/devel/release#go1.23.0),
[1.24](https://go.dev/doc/devel/release#go1.24.0)). We observe that some modules related test may not pass with
Redis Stack 7.2 and some commands are changed with Redis CE 8.0.
Although it is not officially supported, `go-redis/v9` should be able to work with any Redis 7.0+.
Please do refer to the documentation and the tests if you experience any issues. We do plan to update the go version
in the `go.mod` to `go 1.24` in one of the next releases.
@@ -43,10 +43,6 @@ in the `go.mod` to `go 1.24` in one of the next releases.
[Work at Redis](https://redis.com/company/careers/jobs/)
## Documentation
- [English](https://redis.uptrace.dev)
- [简体中文](https://redis.uptrace.dev/zh/)
## Resources
@@ -55,16 +51,18 @@ in the `go.mod` to `go 1.24` in one of the next releases.
- [Reference](https://pkg.go.dev/github.com/redis/go-redis/v9)
- [Examples](https://pkg.go.dev/github.com/redis/go-redis/v9#pkg-examples)
## old documentation
- [English](https://redis.uptrace.dev)
- [简体中文](https://redis.uptrace.dev/zh/)
## Ecosystem
- [Redis Mock](https://github.com/go-redis/redismock)
- [Entra ID (Azure AD)](https://github.com/redis/go-redis-entraid)
- [Distributed Locks](https://github.com/bsm/redislock)
- [Redis Cache](https://github.com/go-redis/cache)
- [Rate limiting](https://github.com/go-redis/redis_rate)
This client also works with [Kvrocks](https://github.com/apache/incubator-kvrocks), a distributed
key value NoSQL database that uses RocksDB as storage engine and is compatible with Redis protocol.
## Features
- Redis commands except QUIT and SYNC.
@@ -75,7 +73,6 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
- [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html).
- [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html).
- [Redis Ring](https://redis.uptrace.dev/guide/ring.html).
- [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html).
- [Redis Probabilistic [RedisStack]](https://redis.io/docs/data-types/probabilistic/)
- [Customizable read and write buffers size.](#custom-buffer-sizes)
@@ -429,6 +426,144 @@ vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello")
res, err := rdb.Do(ctx, "set", "key", "value").Result()
```
## Typed Errors
go-redis provides typed error checking functions for common Redis errors:
```go
// Cluster and replication errors
redis.IsLoadingError(err) // Redis is loading the dataset
redis.IsReadOnlyError(err) // Write to read-only replica
redis.IsClusterDownError(err) // Cluster is down
redis.IsTryAgainError(err) // Command should be retried
redis.IsMasterDownError(err) // Master is down
redis.IsMovedError(err) // Returns (address, true) if key moved
redis.IsAskError(err) // Returns (address, true) if key being migrated
// Connection and resource errors
redis.IsMaxClientsError(err) // Maximum clients reached
redis.IsAuthError(err) // Authentication failed (NOAUTH, WRONGPASS, unauthenticated)
redis.IsPermissionError(err) // Permission denied (NOPERM)
redis.IsOOMError(err) // Out of memory (OOM)
// Transaction errors
redis.IsExecAbortError(err) // Transaction aborted (EXECABORT)
```
### Error Wrapping in Hooks
When wrapping errors in hooks, use custom error types with `Unwrap()` method (preferred) or `fmt.Errorf` with `%w`. Always call `cmd.SetErr()` to preserve error type information:
```go
// Custom error type (preferred)
type AppError struct {
Code string
RequestID string
Err error
}
func (e *AppError) Error() string {
return fmt.Sprintf("[%s] request_id=%s: %v", e.Code, e.RequestID, e.Err)
}
func (e *AppError) Unwrap() error {
return e.Err
}
// Hook implementation
func (h MyHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
err := next(ctx, cmd)
if err != nil {
// Wrap with custom error type
wrappedErr := &AppError{
Code: "REDIS_ERROR",
RequestID: getRequestID(ctx),
Err: err,
}
cmd.SetErr(wrappedErr)
return wrappedErr // Return wrapped error to preserve it
}
return nil
}
}
// Typed error detection works through wrappers
if redis.IsLoadingError(err) {
// Retry logic
}
// Extract custom error if needed
var appErr *AppError
if errors.As(err, &appErr) {
log.Printf("Request: %s", appErr.RequestID)
}
```
Alternatively, use `fmt.Errorf` with `%w`:
```go
wrappedErr := fmt.Errorf("context: %w", err)
cmd.SetErr(wrappedErr)
```
### Pipeline Hook Example
For pipeline operations, use `ProcessPipelineHook`:
```go
type PipelineLoggingHook struct{}
func (h PipelineLoggingHook) DialHook(next redis.DialHook) redis.DialHook {
return next
}
func (h PipelineLoggingHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return next
}
func (h PipelineLoggingHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
start := time.Now()
// Execute the pipeline
err := next(ctx, cmds)
duration := time.Since(start)
log.Printf("Pipeline executed %d commands in %v", len(cmds), duration)
// Process individual command errors
// Note: Individual command errors are already set on each cmd by the pipeline execution
for _, cmd := range cmds {
if cmdErr := cmd.Err(); cmdErr != nil {
// Check for specific error types using typed error functions
if redis.IsAuthError(cmdErr) {
log.Printf("Auth error in pipeline command %s: %v", cmd.Name(), cmdErr)
} else if redis.IsPermissionError(cmdErr) {
log.Printf("Permission error in pipeline command %s: %v", cmd.Name(), cmdErr)
}
// Optionally wrap individual command errors to add context
// The wrapped error preserves type information through errors.As()
wrappedErr := fmt.Errorf("pipeline cmd %s failed: %w", cmd.Name(), cmdErr)
cmd.SetErr(wrappedErr)
}
}
// Return the pipeline-level error (connection errors, etc.)
// You can wrap it if needed, or return it as-is
return err
}
}
// Register the hook
rdb.AddHook(PipelineLoggingHook{})
// Use pipeline - errors are still properly typed
pipe := rdb.Pipeline()
pipe.Set(ctx, "key1", "value1", 0)
pipe.Get(ctx, "key2")
_, err := pipe.Exec(ctx)
```
## Run the test
+224
View File
@@ -1,5 +1,229 @@
# Release Notes
# 9.17.2 (2025-12-01)
## 🐛 Bug Fixes
- **Connection Pool**: Fixed critical race condition in turn management that could cause connection leaks when dial goroutines complete after request timeout ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
- **Context Timeout**: Improved context timeout calculation to use minimum of remaining time and DialTimeout, preventing goroutines from waiting longer than necessary ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
## 🧰 Maintenance
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.54.0 to 0.55.0 ([#3627](https://github.com/redis/go-redis/pull/3627))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@cyningsun](https://github.com/cyningsun) and [@ndyakov](https://github.com/ndyakov)
---
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.1...v9.17.2
# 9.17.1 (2025-11-25)
## 🐛 Bug Fixes
- add wait to keyless commands list ([#3615](https://github.com/redis/go-redis/pull/3615)) by [@marcoferrer](https://github.com/marcoferrer)
- fix(time): remove cached time optimization ([#3611](https://github.com/redis/go-redis/pull/3611)) by [@ndyakov](https://github.com/ndyakov)
## 🧰 Maintenance
- chore(deps): bump golangci/golangci-lint-action from 9.0.0 to 9.1.0 ([#3609](https://github.com/redis/go-redis/pull/3609))
- chore(deps): bump actions/checkout from 5 to 6 ([#3610](https://github.com/redis/go-redis/pull/3610))
- chore(script): fix help call in tag.sh ([#3606](https://github.com/redis/go-redis/pull/3606)) by [@ndyakov](https://github.com/ndyakov)
## Contributors
We'd like to thank all the contributors who worked on this release!
[@marcoferrer](https://github.com/marcoferrer) and [@ndyakov](https://github.com/ndyakov)
---
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.0...v9.17.1
# 9.17.0 (2025-11-19)
## 🚀 Highlights
### Redis 8.4 Support
Added support for Redis 8.4, including new commands and features ([#3572](https://github.com/redis/go-redis/pull/3572))
### Typed Errors
Introduced typed errors for better error handling using `errors.As` instead of string checks. Errors can now be wrapped and set to commands in hooks without breaking library functionality ([#3602](https://github.com/redis/go-redis/pull/3602))
### New Commands
- **CAS/CAD Commands**: Added support for Compare-And-Set/Compare-And-Delete operations with conditional matching (`IFEQ`, `IFNE`, `IFDEQ`, `IFDNE`) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595))
- **MSETEX**: Atomically set multiple key-value pairs with expiration options and conditional modes ([#3580](https://github.com/redis/go-redis/pull/3580))
- **XReadGroup CLAIM**: Consume both incoming and idle pending entries from streams in a single call ([#3578](https://github.com/redis/go-redis/pull/3578))
- **ACL Commands**: Added `ACLGenPass`, `ACLUsers`, and `ACLWhoAmI` ([#3576](https://github.com/redis/go-redis/pull/3576))
- **SLOWLOG Commands**: Added `SLOWLOG LEN` and `SLOWLOG RESET` ([#3585](https://github.com/redis/go-redis/pull/3585))
- **LATENCY Commands**: Added `LATENCY LATEST` and `LATENCY RESET` ([#3584](https://github.com/redis/go-redis/pull/3584))
### Search & Vector Improvements
- **Hybrid Search**: Added **EXPERIMENTAL** support for the new `FT.HYBRID` command ([#3573](https://github.com/redis/go-redis/pull/3573))
- **Vector Range**: Added `VRANGE` command for vector sets ([#3543](https://github.com/redis/go-redis/pull/3543))
- **FT.INFO Enhancements**: Added vector-specific attributes in FT.INFO response ([#3596](https://github.com/redis/go-redis/pull/3596))
### Connection Pool Improvements
- **Improved Connection Success Rate**: Implemented FIFO queue-based fairness and context pattern for connection creation to prevent premature cancellation under high concurrency ([#3518](https://github.com/redis/go-redis/pull/3518))
- **Connection State Machine**: Resolved race conditions and improved pool performance with proper state tracking ([#3559](https://github.com/redis/go-redis/pull/3559))
- **Pool Performance**: Significant performance improvements with faster semaphores, lockless hook manager, and reduced allocations (47-67% faster Get/Put operations) ([#3565](https://github.com/redis/go-redis/pull/3565))
### Metrics & Observability
- **Canceled Metric Attribute**: Added 'canceled' metrics attribute to distinguish context cancellation errors from other errors ([#3566](https://github.com/redis/go-redis/pull/3566))
## ✨ New Features
- Typed errors with wrapping support ([#3602](https://github.com/redis/go-redis/pull/3602)) by [@ndyakov](https://github.com/ndyakov)
- CAS/CAD commands (marked as experimental) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595)) by [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis)
- MSETEX command support ([#3580](https://github.com/redis/go-redis/pull/3580)) by [@ofekshenawa](https://github.com/ofekshenawa)
- XReadGroup CLAIM argument ([#3578](https://github.com/redis/go-redis/pull/3578)) by [@ofekshenawa](https://github.com/ofekshenawa)
- ACL commands: GenPass, Users, WhoAmI ([#3576](https://github.com/redis/go-redis/pull/3576)) by [@destinyoooo](https://github.com/destinyoooo)
- SLOWLOG commands: LEN, RESET ([#3585](https://github.com/redis/go-redis/pull/3585)) by [@destinyoooo](https://github.com/destinyoooo)
- LATENCY commands: LATEST, RESET ([#3584](https://github.com/redis/go-redis/pull/3584)) by [@destinyoooo](https://github.com/destinyoooo)
- Hybrid search command (FT.HYBRID) ([#3573](https://github.com/redis/go-redis/pull/3573)) by [@htemelski-redis](https://github.com/htemelski-redis)
- Vector range command (VRANGE) ([#3543](https://github.com/redis/go-redis/pull/3543)) by [@cxljs](https://github.com/cxljs)
- Vector-specific attributes in FT.INFO ([#3596](https://github.com/redis/go-redis/pull/3596)) by [@ndyakov](https://github.com/ndyakov)
- Improved connection pool success rate with FIFO queue ([#3518](https://github.com/redis/go-redis/pull/3518)) by [@cyningsun](https://github.com/cyningsun)
- Canceled metrics attribute for context errors ([#3566](https://github.com/redis/go-redis/pull/3566)) by [@pvragov](https://github.com/pvragov)
## 🐛 Bug Fixes
- Fixed Failover Client MaintNotificationsConfig ([#3600](https://github.com/redis/go-redis/pull/3600)) by [@ajax16384](https://github.com/ajax16384)
- Fixed ACLGenPass function to use the bit parameter ([#3597](https://github.com/redis/go-redis/pull/3597)) by [@destinyoooo](https://github.com/destinyoooo)
- Return error instead of panic from commands ([#3568](https://github.com/redis/go-redis/pull/3568)) by [@dragneelfps](https://github.com/dragneelfps)
- Safety harness in `joinErrors` to prevent panic ([#3577](https://github.com/redis/go-redis/pull/3577)) by [@manisharma](https://github.com/manisharma)
## ⚡ Performance
- Connection state machine with race condition fixes ([#3559](https://github.com/redis/go-redis/pull/3559)) by [@ndyakov](https://github.com/ndyakov)
- Pool performance improvements: 47-67% faster Get/Put, 33% less memory, 50% fewer allocations ([#3565](https://github.com/redis/go-redis/pull/3565)) by [@ndyakov](https://github.com/ndyakov)
## 🧪 Testing & Infrastructure
- Updated to Redis 8.4.0 image ([#3603](https://github.com/redis/go-redis/pull/3603)) by [@ndyakov](https://github.com/ndyakov)
- Added Redis 8.4-RC1-pre to CI ([#3572](https://github.com/redis/go-redis/pull/3572)) by [@ndyakov](https://github.com/ndyakov)
- Refactored tests for idiomatic Go ([#3561](https://github.com/redis/go-redis/pull/3561), [#3562](https://github.com/redis/go-redis/pull/3562), [#3563](https://github.com/redis/go-redis/pull/3563)) by [@12ya](https://github.com/12ya)
## 👥 Contributors
We'd like to thank all the contributors who worked on this release!
[@12ya](https://github.com/12ya), [@ajax16384](https://github.com/ajax16384), [@cxljs](https://github.com/cxljs), [@cyningsun](https://github.com/cyningsun), [@destinyoooo](https://github.com/destinyoooo), [@dragneelfps](https://github.com/dragneelfps), [@htemelski-redis](https://github.com/htemelski-redis), [@manisharma](https://github.com/manisharma), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@pvragov](https://github.com/pvragov)
---
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.16.0...v9.17.0
# 9.16.0 (2025-10-23)
## 🚀 Highlights
### Maintenance Notifications Support
This release introduces comprehensive support for Redis maintenance notifications, enabling applications to handle server maintenance events gracefully. The new `maintnotifications` package provides:
- **RESP3 Push Notifications**: Full support for Redis RESP3 protocol push notifications
- **Connection Handoff**: Automatic connection migration during server maintenance with configurable retry policies and circuit breakers
- **Graceful Degradation**: Configurable timeout relaxation during maintenance windows to prevent false failures
- **Event-Driven Architecture**: Background workers with on-demand scaling for efficient handoff processing
- **Production-Ready**: Comprehensive E2E testing framework and monitoring capabilities
For detailed usage examples and configuration options, see the [maintenance notifications documentation](maintnotifications/README.md).
## ✨ New Features
- **Trace Filtering**: Add support for filtering traces for specific commands, including pipeline operations and dial operations ([#3519](https://github.com/redis/go-redis/pull/3519), [#3550](https://github.com/redis/go-redis/pull/3550))
- New `TraceCmdFilter` option to selectively trace commands
- Reduces overhead by excluding high-frequency or low-value commands from traces
## 🐛 Bug Fixes
- **Pipeline Error Handling**: Fix issue where pipeline repeatedly sets the same error ([#3525](https://github.com/redis/go-redis/pull/3525))
- **Connection Pool**: Ensure re-authentication does not interfere with connection handoff operations ([#3547](https://github.com/redis/go-redis/pull/3547))
## 🔧 Improvements
- **Hash Commands**: Update hash command implementations ([#3523](https://github.com/redis/go-redis/pull/3523))
- **OpenTelemetry**: Use `metric.WithAttributeSet` to avoid unnecessary attribute copying in redisotel ([#3552](https://github.com/redis/go-redis/pull/3552))
## 📚 Documentation
- **Cluster Client**: Add explanation for why `MaxRetries` is disabled for `ClusterClient` ([#3551](https://github.com/redis/go-redis/pull/3551))
## 🧪 Testing & Infrastructure
- **E2E Testing**: Upgrade E2E testing framework with improved reliability and coverage ([#3541](https://github.com/redis/go-redis/pull/3541))
- **Release Process**: Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
## 📦 Dependencies
- Bump `rojopolis/spellcheck-github-actions` from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
- Bump `github/codeql-action` from 3 to 4 ([#3544](https://github.com/redis/go-redis/pull/3544))
## 👥 Contributors
We'd like to thank all the contributors who worked on this release!
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@Sovietaced](https://github.com/Sovietaced), [@Udhayarajan](https://github.com/Udhayarajan), [@boekkooi-impossiblecloud](https://github.com/boekkooi-impossiblecloud), [@Pika-Gopher](https://github.com/Pika-Gopher), [@cxljs](https://github.com/cxljs), [@huiyifyj](https://github.com/huiyifyj), [@omid-h70](https://github.com/omid-h70)
---
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.14.0...v9.16.0
# 9.15.0 was accidentally released. Please use version 9.16.0 instead.
# 9.15.0-beta.3 (2025-09-26)
## Highlights
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
# Changes
- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523))
## 🚀 New Features
- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
## 🐛 Bug Fixes
- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525))
## 🧰 Maintenance
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526))
- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70)
# 9.15.0-beta.1 (2025-09-10)
## Highlights
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
### Hitless Upgrades
Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters.
You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless).
# Changes
## 🚀 New Features
- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
## Contributors
We'd like to thank all the contributors who worked on this release!
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa)
# 9.14.0 (2025-09-10)
## Highlights
+27
View File
@@ -8,8 +8,12 @@ type ACLCmdable interface {
ACLLog(ctx context.Context, count int64) *ACLLogCmd
ACLLogReset(ctx context.Context) *StatusCmd
ACLGenPass(ctx context.Context, bit int) *StringCmd
ACLSetUser(ctx context.Context, username string, rules ...string) *StatusCmd
ACLDelUser(ctx context.Context, username string) *IntCmd
ACLUsers(ctx context.Context) *StringSliceCmd
ACLWhoAmI(ctx context.Context) *StringCmd
ACLList(ctx context.Context) *StringSliceCmd
ACLCat(ctx context.Context) *StringSliceCmd
@@ -65,6 +69,29 @@ func (c cmdable) ACLSetUser(ctx context.Context, username string, rules ...strin
return cmd
}
func (c cmdable) ACLGenPass(ctx context.Context, bit int) *StringCmd {
args := make([]interface{}, 0, 3)
args = append(args, "acl", "genpass")
if bit > 0 {
args = append(args, bit)
}
cmd := NewStringCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) ACLUsers(ctx context.Context) *StringSliceCmd {
cmd := NewStringSliceCmd(ctx, "acl", "users")
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) ACLWhoAmI(ctx context.Context) *StringCmd {
cmd := NewStringCmd(ctx, "acl", "whoami")
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) ACLList(ctx context.Context) *StringSliceCmd {
cmd := NewStringSliceCmd(ctx, "acl", "list")
_ = c(ctx, cmd)
+111
View File
@@ -0,0 +1,111 @@
package redis
import (
"context"
"errors"
"net"
"time"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/push"
)
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
var ErrInvalidCommand = errors.New("invalid command type")
// ErrInvalidPool is returned when the pool type is not supported.
var ErrInvalidPool = errors.New("invalid pool type")
// newClientAdapter creates a new client adapter for regular Redis clients.
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
return &clientAdapter{client: client}
}
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
type clientAdapter struct {
client *baseClient
}
// GetOptions returns the client options.
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
return &optionsAdapter{options: ca.client.opt}
}
// GetPushProcessor returns the client's push notification processor.
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
}
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
type optionsAdapter struct {
options *Options
}
// GetReadTimeout returns the read timeout.
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
return oa.options.ReadTimeout
}
// GetWriteTimeout returns the write timeout.
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
return oa.options.WriteTimeout
}
// GetNetwork returns the network type.
func (oa *optionsAdapter) GetNetwork() string {
return oa.options.Network
}
// GetAddr returns the connection address.
func (oa *optionsAdapter) GetAddr() string {
return oa.options.Addr
}
// IsTLSEnabled returns true if TLS is enabled.
func (oa *optionsAdapter) IsTLSEnabled() bool {
return oa.options.TLSConfig != nil
}
// GetProtocol returns the protocol version.
func (oa *optionsAdapter) GetProtocol() int {
return oa.options.Protocol
}
// GetPoolSize returns the connection pool size.
func (oa *optionsAdapter) GetPoolSize() int {
return oa.options.PoolSize
}
// NewDialer returns a new dialer function for the connection.
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
baseDialer := oa.options.NewDialer()
return func(ctx context.Context) (net.Conn, error) {
// Extract network and address from the options
network := oa.options.Network
addr := oa.options.Addr
return baseDialer(ctx, network, addr)
}
}
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
type pushProcessorAdapter struct {
processor push.NotificationProcessor
}
// RegisterHandler registers a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
if pushHandler, ok := handler.(push.NotificationHandler); ok {
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
}
return errors.New("handler must implement push.NotificationHandler")
}
// UnregisterHandler removes a handler for a specific push notification name.
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
return ppa.processor.UnregisterHandler(pushNotificationName)
}
// GetHandler returns the handler for a specific push notification name.
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
return ppa.processor.GetHandler(pushNotificationName)
}
+6 -2
View File
@@ -141,7 +141,9 @@ func (c cmdable) BitPos(ctx context.Context, key string, bit int64, pos ...int64
args[3] = pos[0]
args[4] = pos[1]
default:
panic("too many arguments")
cmd := NewIntCmd(ctx)
cmd.SetErr(errors.New("too many arguments"))
return cmd
}
cmd := NewIntCmd(ctx, args...)
_ = c(ctx, cmd)
@@ -182,7 +184,9 @@ func (c cmdable) BitFieldRO(ctx context.Context, key string, values ...interface
args[0] = "BITFIELD_RO"
args[1] = key
if len(values)%2 != 0 {
panic("BitFieldRO: invalid number of arguments, must be even")
c := NewIntSliceCmd(ctx)
c.SetErr(errors.New("BitFieldRO: invalid number of arguments, must be even"))
return c
}
for i := 0; i < len(values); i += 2 {
args = append(args, "GET", values[i], values[i+1])
+169 -3
View File
@@ -64,6 +64,7 @@ var keylessCommands = map[string]struct{}{
"sync": {},
"unsubscribe": {},
"unwatch": {},
"wait": {},
}
type Cmder interface {
@@ -698,6 +699,68 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) {
//------------------------------------------------------------------------------
// DigestCmd is a command that returns a uint64 xxh3 hash digest.
//
// This command is specifically designed for the Redis DIGEST command,
// which returns the xxh3 hash of a key's value as a hex string.
// The hex string is automatically parsed to a uint64 value.
//
// The digest can be used for optimistic locking with SetIFDEQ, SetIFDNE,
// and DelExArgs commands.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/digest/
type DigestCmd struct {
baseCmd
val uint64
}
var _ Cmder = (*DigestCmd)(nil)
func NewDigestCmd(ctx context.Context, args ...interface{}) *DigestCmd {
return &DigestCmd{
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
func (cmd *DigestCmd) SetVal(val uint64) {
cmd.val = val
}
func (cmd *DigestCmd) Val() uint64 {
return cmd.val
}
func (cmd *DigestCmd) Result() (uint64, error) {
return cmd.val, cmd.err
}
func (cmd *DigestCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *DigestCmd) readReply(rd *proto.Reader) (err error) {
// Redis DIGEST command returns a hex string (e.g., "a1b2c3d4e5f67890")
// We parse it as a uint64 xxh3 hash value
var hexStr string
hexStr, err = rd.ReadString()
if err != nil {
return err
}
// Parse hex string to uint64
cmd.val, err = strconv.ParseUint(hexStr, 16, 64)
return err
}
//------------------------------------------------------------------------------
type IntSliceCmd struct {
baseCmd
@@ -1585,6 +1648,12 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error {
type XMessage struct {
ID string
Values map[string]interface{}
// MillisElapsedFromDelivery is the number of milliseconds since the entry was last delivered.
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
MillisElapsedFromDelivery int64
// DeliveredCount is the number of times the entry was delivered.
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
DeliveredCount int64
}
type XMessageSliceCmd struct {
@@ -1641,10 +1710,16 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) {
}
func readXMessage(rd *proto.Reader) (XMessage, error) {
if err := rd.ReadFixedArrayLen(2); err != nil {
// Read array length can be 2 or 4 (with CLAIM metadata)
n, err := rd.ReadArrayLen()
if err != nil {
return XMessage{}, err
}
if n != 2 && n != 4 {
return XMessage{}, fmt.Errorf("redis: got %d elements in the XMessage array, expected 2 or 4", n)
}
id, err := rd.ReadString()
if err != nil {
return XMessage{}, err
@@ -1657,10 +1732,24 @@ func readXMessage(rd *proto.Reader) (XMessage, error) {
}
}
return XMessage{
msg := XMessage{
ID: id,
Values: v,
}, nil
}
if n == 4 {
msg.MillisElapsedFromDelivery, err = rd.ReadInt()
if err != nil {
return XMessage{}, err
}
msg.DeliveredCount, err = rd.ReadInt()
if err != nil {
return XMessage{}, err
}
}
return msg, nil
}
func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) {
@@ -3768,6 +3857,83 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
//-----------------------------------------------------------------------
type Latency struct {
Name string
Time time.Time
Latest time.Duration
Max time.Duration
}
type LatencyCmd struct {
baseCmd
val []Latency
}
var _ Cmder = (*LatencyCmd)(nil)
func NewLatencyCmd(ctx context.Context, args ...interface{}) *LatencyCmd {
return &LatencyCmd{
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
}
}
func (cmd *LatencyCmd) SetVal(val []Latency) {
cmd.val = val
}
func (cmd *LatencyCmd) Val() []Latency {
return cmd.val
}
func (cmd *LatencyCmd) Result() ([]Latency, error) {
return cmd.val, cmd.err
}
func (cmd *LatencyCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *LatencyCmd) readReply(rd *proto.Reader) error {
n, err := rd.ReadArrayLen()
if err != nil {
return err
}
cmd.val = make([]Latency, n)
for i := 0; i < len(cmd.val); i++ {
nn, err := rd.ReadArrayLen()
if err != nil {
return err
}
if nn < 3 {
return fmt.Errorf("redis: got %d elements in latency get, expected at least 3", nn)
}
if cmd.val[i].Name, err = rd.ReadString(); err != nil {
return err
}
createdAt, err := rd.ReadInt()
if err != nil {
return err
}
cmd.val[i].Time = time.Unix(createdAt, 0)
latest, err := rd.ReadInt()
if err != nil {
return err
}
cmd.val[i].Latest = time.Duration(latest) * time.Millisecond
maximum, err := rd.ReadInt()
if err != nil {
return err
}
cmd.val[i].Max = time.Duration(maximum) * time.Millisecond
}
return nil
}
//-----------------------------------------------------------------------
type MapStringInterfaceCmd struct {
baseCmd
+53 -1
View File
@@ -193,6 +193,7 @@ type Cmdable interface {
ClientID(ctx context.Context) *IntCmd
ClientUnblock(ctx context.Context, id int64) *IntCmd
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
ConfigResetStat(ctx context.Context) *StatusCmd
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
@@ -210,9 +211,13 @@ type Cmdable interface {
ShutdownNoSave(ctx context.Context) *StatusCmd
SlaveOf(ctx context.Context, host, port string) *StatusCmd
SlowLogGet(ctx context.Context, num int64) *SlowLogCmd
SlowLogLen(ctx context.Context) *IntCmd
SlowLogReset(ctx context.Context) *StatusCmd
Time(ctx context.Context) *TimeCmd
DebugObject(ctx context.Context, key string) *StringCmd
MemoryUsage(ctx context.Context, key string, samples ...int) *IntCmd
Latency(ctx context.Context) *LatencyCmd
LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd
ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd
@@ -519,6 +524,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
return cmd
}
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
// When enabled, the client will receive push notifications about Redis maintenance events.
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
args := []interface{}{"client", "maint_notifications"}
if enabled {
if endpointType == "" {
endpointType = "none"
}
args = append(args, "on", "moving-endpoint-type", endpointType)
} else {
args = append(args, "off")
}
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// ------------------------------------------------------------------------------------------------
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
@@ -655,6 +677,34 @@ func (c cmdable) SlowLogGet(ctx context.Context, num int64) *SlowLogCmd {
return cmd
}
func (c cmdable) SlowLogLen(ctx context.Context) *IntCmd {
cmd := NewIntCmd(ctx, "slowlog", "len")
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) SlowLogReset(ctx context.Context) *StatusCmd {
cmd := NewStatusCmd(ctx, "slowlog", "reset")
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) Latency(ctx context.Context) *LatencyCmd {
cmd := NewLatencyCmd(ctx, "latency", "latest")
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd {
args := make([]interface{}, 2+len(events))
args[0] = "latency"
args[1] = "reset"
copy(args[2:], events)
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) Sync(_ context.Context) {
panic("not implemented")
}
@@ -675,7 +725,9 @@ func (c cmdable) MemoryUsage(ctx context.Context, key string, samples ...int) *I
args := []interface{}{"memory", "usage", key}
if len(samples) > 0 {
if len(samples) != 1 {
panic("MemoryUsage expects single sample count")
cmd := NewIntCmd(ctx)
cmd.SetErr(errors.New("MemoryUsage expects single sample count"))
return cmd
}
args = append(args, "SAMPLES", samples[0])
}
+5 -5
View File
@@ -2,7 +2,7 @@
services:
redis:
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
platform: linux/amd64
container_name: redis-standalone
environment:
@@ -23,7 +23,7 @@ services:
- all
osscluster:
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
platform: linux/amd64
container_name: redis-osscluster
environment:
@@ -40,7 +40,7 @@ services:
- all
sentinel-cluster:
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
platform: linux/amd64
container_name: redis-sentinel-cluster
network_mode: "host"
@@ -60,7 +60,7 @@ services:
- all
sentinel:
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
platform: linux/amd64
container_name: redis-sentinel
depends_on:
@@ -84,7 +84,7 @@ services:
- all
ring-cluster:
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
platform: linux/amd64
container_name: redis-ring-cluster
environment:
+207 -45
View File
@@ -52,34 +52,82 @@ type Error interface {
var _ Error = proto.RedisError("")
func isContextError(err error) bool {
switch err {
case context.Canceled, context.DeadlineExceeded:
return true
default:
return false
// Check for wrapped context errors using errors.Is
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}
// isTimeoutError checks if an error is a timeout error, even if wrapped.
// Returns (isTimeout, shouldRetryOnTimeout) where:
// - isTimeout: true if the error is any kind of timeout error
// - shouldRetryOnTimeout: true if Timeout() method returns true
func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) {
// Check for timeoutError interface (works with wrapped errors)
var te timeoutError
if errors.As(err, &te) {
return true, te.Timeout()
}
// Check for net.Error specifically (common case for network timeouts)
var netErr net.Error
if errors.As(err, &netErr) {
return true, netErr.Timeout()
}
return false, false
}
func shouldRetry(err error, retryTimeout bool) bool {
switch err {
case io.EOF, io.ErrUnexpectedEOF:
return true
case nil, context.Canceled, context.DeadlineExceeded:
if err == nil {
return false
case pool.ErrPoolTimeout:
}
// Check for EOF errors (works with wrapped errors)
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// Check for context errors (works with wrapped errors)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
// Check for pool timeout (works with wrapped errors)
if errors.Is(err, pool.ErrPoolTimeout) {
// connection pool timeout, increase retries. #3289
return true
}
if v, ok := err.(timeoutError); ok {
if v.Timeout() {
// Check for timeout errors (works with wrapped errors)
if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout {
if hasTimeoutFlag {
return retryTimeout
}
return true
}
// Check for typed Redis errors using errors.As (works with wrapped errors)
if proto.IsMaxClientsError(err) {
return true
}
if proto.IsLoadingError(err) {
return true
}
if proto.IsReadOnlyError(err) {
return true
}
if proto.IsMasterDownError(err) {
return true
}
if proto.IsClusterDownError(err) {
return true
}
if proto.IsTryAgainError(err) {
return true
}
// Fallback to string checking for backward compatibility with plain errors
s := err.Error()
if s == "ERR max number of clients reached" {
if strings.HasPrefix(s, "ERR max number of clients reached") {
return true
}
if strings.HasPrefix(s, "LOADING ") {
@@ -88,29 +136,42 @@ func shouldRetry(err error, retryTimeout bool) bool {
if strings.HasPrefix(s, "READONLY ") {
return true
}
if strings.HasPrefix(s, "MASTERDOWN ") {
return true
}
if strings.HasPrefix(s, "CLUSTERDOWN ") {
return true
}
if strings.HasPrefix(s, "TRYAGAIN ") {
return true
}
if strings.HasPrefix(s, "MASTERDOWN ") {
return true
}
return false
}
func isRedisError(err error) bool {
_, ok := err.(proto.RedisError)
return ok
// Check if error implements the Error interface (works with wrapped errors)
var redisErr Error
if errors.As(err, &redisErr) {
return true
}
// Also check for proto.RedisError specifically
var protoRedisErr proto.RedisError
return errors.As(err, &protoRedisErr)
}
func isBadConn(err error, allowTimeout bool, addr string) bool {
switch err {
case nil:
if err == nil {
return false
case context.Canceled, context.DeadlineExceeded:
}
// Check for context errors (works with wrapped errors)
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
// Check for pool timeout errors (works with wrapped errors)
if errors.Is(err, pool.ErrConnUnusableTimeout) {
return true
}
@@ -131,7 +192,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
}
if allowTimeout {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Check for network timeout errors (works with wrapped errors)
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return false
}
}
@@ -140,44 +203,143 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
}
func isMovedError(err error) (moved bool, ask bool, addr string) {
if !isRedisError(err) {
return
// Check for typed MovedError
if movedErr, ok := proto.IsMovedError(err); ok {
addr = movedErr.Addr()
addr = internal.GetAddr(addr)
return true, false, addr
}
// Check for typed AskError
if askErr, ok := proto.IsAskError(err); ok {
addr = askErr.Addr()
addr = internal.GetAddr(addr)
return false, true, addr
}
// Fallback to string checking for backward compatibility
s := err.Error()
switch {
case strings.HasPrefix(s, "MOVED "):
moved = true
case strings.HasPrefix(s, "ASK "):
ask = true
default:
return
if strings.HasPrefix(s, "MOVED ") {
// Parse: MOVED 3999 127.0.0.1:6381
parts := strings.Split(s, " ")
if len(parts) == 3 {
addr = internal.GetAddr(parts[2])
return true, false, addr
}
}
if strings.HasPrefix(s, "ASK ") {
// Parse: ASK 3999 127.0.0.1:6381
parts := strings.Split(s, " ")
if len(parts) == 3 {
addr = internal.GetAddr(parts[2])
return false, true, addr
}
}
ind := strings.LastIndex(s, " ")
if ind == -1 {
return false, false, ""
}
addr = s[ind+1:]
addr = internal.GetAddr(addr)
return
return false, false, ""
}
func isLoadingError(err error) bool {
return strings.HasPrefix(err.Error(), "LOADING ")
return proto.IsLoadingError(err)
}
func isReadOnlyError(err error) bool {
return strings.HasPrefix(err.Error(), "READONLY ")
return proto.IsReadOnlyError(err)
}
func isMovedSameConnAddr(err error, addr string) bool {
redisError := err.Error()
if !strings.HasPrefix(redisError, "MOVED ") {
return false
if movedErr, ok := proto.IsMovedError(err); ok {
return strings.HasSuffix(movedErr.Addr(), addr)
}
return strings.HasSuffix(redisError, " "+addr)
return false
}
//------------------------------------------------------------------------------
// Typed error checking functions for public use.
// These functions work correctly even when errors are wrapped in hooks.
// IsLoadingError checks if an error is a Redis LOADING error, even if wrapped.
// LOADING errors occur when Redis is loading the dataset in memory.
func IsLoadingError(err error) bool {
return proto.IsLoadingError(err)
}
// IsReadOnlyError checks if an error is a Redis READONLY error, even if wrapped.
// READONLY errors occur when trying to write to a read-only replica.
func IsReadOnlyError(err error) bool {
return proto.IsReadOnlyError(err)
}
// IsClusterDownError checks if an error is a Redis CLUSTERDOWN error, even if wrapped.
// CLUSTERDOWN errors occur when the cluster is down.
func IsClusterDownError(err error) bool {
return proto.IsClusterDownError(err)
}
// IsTryAgainError checks if an error is a Redis TRYAGAIN error, even if wrapped.
// TRYAGAIN errors occur when a command cannot be processed and should be retried.
func IsTryAgainError(err error) bool {
return proto.IsTryAgainError(err)
}
// IsMasterDownError checks if an error is a Redis MASTERDOWN error, even if wrapped.
// MASTERDOWN errors occur when the master is down.
func IsMasterDownError(err error) bool {
return proto.IsMasterDownError(err)
}
// IsMaxClientsError checks if an error is a Redis max clients error, even if wrapped.
// This error occurs when the maximum number of clients has been reached.
func IsMaxClientsError(err error) bool {
return proto.IsMaxClientsError(err)
}
// IsMovedError checks if an error is a Redis MOVED error, even if wrapped.
// MOVED errors occur in cluster mode when a key has been moved to a different node.
// Returns the address of the node where the key has been moved and a boolean indicating if it's a MOVED error.
func IsMovedError(err error) (addr string, ok bool) {
if movedErr, isMovedErr := proto.IsMovedError(err); isMovedErr {
return movedErr.Addr(), true
}
return "", false
}
// IsAskError checks if an error is a Redis ASK error, even if wrapped.
// ASK errors occur in cluster mode when a key is being migrated and the client should ask another node.
// Returns the address of the node to ask and a boolean indicating if it's an ASK error.
func IsAskError(err error) (addr string, ok bool) {
if askErr, isAskErr := proto.IsAskError(err); isAskErr {
return askErr.Addr(), true
}
return "", false
}
// IsAuthError checks if an error is a Redis authentication error, even if wrapped.
// Authentication errors occur when:
// - NOAUTH: Redis requires authentication but none was provided
// - WRONGPASS: Redis authentication failed due to incorrect password
// - unauthenticated: Error returned when password changed
func IsAuthError(err error) bool {
return proto.IsAuthError(err)
}
// IsPermissionError checks if an error is a Redis permission error, even if wrapped.
// Permission errors (NOPERM) occur when a user does not have permission to execute a command.
func IsPermissionError(err error) bool {
return proto.IsPermissionError(err)
}
// IsExecAbortError checks if an error is a Redis EXECABORT error, even if wrapped.
// EXECABORT errors occur when a transaction is aborted.
func IsExecAbortError(err error) bool {
return proto.IsExecAbortError(err)
}
// IsOOMError checks if an error is a Redis OOM (Out Of Memory) error, even if wrapped.
// OOM errors occur when Redis is out of memory.
func IsOOMError(err error) bool {
return proto.IsOOMError(err)
}
//------------------------------------------------------------------------------
+4 -4
View File
@@ -116,16 +116,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice
// HSet accepts values in following formats:
//
// - HSet("myhash", "key1", "value1", "key2", "value2")
// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2")
//
// - HSet("myhash", []string{"key1", "value1", "key2", "value2"})
// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"})
//
// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
//
// Playing struct With "redis" tag.
// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` }
//
// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
//
// For struct, can be a structure pointer type, we only parse the field whose tag is redis.
// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it,
@@ -0,0 +1,100 @@
package streaming
import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
// ConnReAuthCredentialsListener is a credentials listener for a specific connection
// that triggers re-authentication when credentials change.
//
// This listener implements the auth.CredentialsListener interface and is subscribed
// to a StreamingCredentialsProvider. When new credentials are received via OnNext,
// it marks the connection for re-authentication through the manager.
//
// The re-authentication is always performed asynchronously to avoid blocking the
// credentials provider and to prevent potential deadlocks with the pool semaphore.
// The actual re-auth happens when the connection is returned to the pool in an idle state.
//
// Lifecycle:
// - Created during connection initialization via Manager.Listener()
// - Subscribed to the StreamingCredentialsProvider
// - Receives credential updates via OnNext()
// - Cleaned up when connection is removed from pool via Manager.RemoveListener()
type ConnReAuthCredentialsListener struct {
// reAuth is the function to re-authenticate the connection with new credentials
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
// onErr is the function to call when re-authentication or acquisition fails
onErr func(conn *pool.Conn, err error)
// conn is the connection this listener is associated with
conn *pool.Conn
// manager is the streaming credentials manager for coordinating re-auth
manager *Manager
}
// OnNext is called when new credentials are received from the StreamingCredentialsProvider.
//
// This method marks the connection for asynchronous re-authentication. The actual
// re-authentication happens in the background when the connection is returned to the
// pool and is in an idle state.
//
// Asynchronous re-auth is used to:
// - Avoid blocking the credentials provider's notification goroutine
// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes)
// - Ensure re-auth happens when the connection is safe to use (not processing commands)
//
// The reAuthFn callback receives:
// - nil if the connection was successfully acquired for re-auth
// - error if acquisition timed out or failed
//
// Thread-safe: Called by the credentials provider's notification goroutine.
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
return
}
// Always use async reauth to avoid complex pool semaphore issues
// The synchronous path can cause deadlocks in the pool's semaphore mechanism
// when called from the Subscribe goroutine, especially with small pool sizes.
// The connection pool hook will re-authenticate the connection when it is
// returned to the pool in a clean, idle state.
c.manager.MarkForReAuth(c.conn, func(err error) {
// err is from connection acquisition (timeout, etc.)
if err != nil {
// Log the error
c.OnError(err)
return
}
// err is from reauth command execution
err = c.reAuth(c.conn, credentials)
if err != nil {
// Log the error
c.OnError(err)
return
}
})
}
// OnError is called when an error occurs during credential streaming or re-authentication.
//
// This method can be called from:
// - The StreamingCredentialsProvider when there's an error in the credentials stream
// - The re-auth process when connection acquisition times out
// - The re-auth process when the AUTH command fails
//
// The error is delegated to the onErr callback provided during listener creation.
//
// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker).
func (c *ConnReAuthCredentialsListener) OnError(err error) {
if c.onErr == nil {
return
}
c.onErr(c.conn, err)
}
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
@@ -0,0 +1,77 @@
package streaming
import (
"sync"
"github.com/redis/go-redis/v9/auth"
)
// CredentialsListeners is a thread-safe collection of credentials listeners
// indexed by connection ID.
//
// This collection is used by the Manager to maintain a registry of listeners
// for each connection in the pool. Listeners are reused when connections are
// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions
// to the StreamingCredentialsProvider.
//
// The collection supports concurrent access from multiple goroutines during
// connection initialization, credential updates, and connection removal.
type CredentialsListeners struct {
// listeners maps connection ID to credentials listener
listeners map[uint64]auth.CredentialsListener
// lock protects concurrent access to the listeners map
lock sync.RWMutex
}
// NewCredentialsListeners creates a new thread-safe credentials listeners collection.
func NewCredentialsListeners() *CredentialsListeners {
return &CredentialsListeners{
listeners: make(map[uint64]auth.CredentialsListener),
}
}
// Add adds or updates a credentials listener for a connection.
//
// If a listener already exists for the connection ID, it is replaced.
// This is safe because the old listener should have been unsubscribed
// before the connection was reinitialized.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
c.lock.Lock()
defer c.lock.Unlock()
if c.listeners == nil {
c.listeners = make(map[uint64]auth.CredentialsListener)
}
c.listeners[connID] = listener
}
// Get retrieves the credentials listener for a connection.
//
// Returns:
// - listener: The credentials listener for the connection, or nil if not found
// - ok: true if a listener exists for the connection ID, false otherwise
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
if len(c.listeners) == 0 {
return nil, false
}
listener, ok := c.listeners[connID]
return listener, ok
}
// Remove removes the credentials listener for a connection.
//
// This is called when a connection is removed from the pool to prevent
// memory leaks. If no listener exists for the connection ID, this is a no-op.
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (c *CredentialsListeners) Remove(connID uint64) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.listeners, connID)
}
+137
View File
@@ -0,0 +1,137 @@
package streaming
import (
"errors"
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
)
// Manager coordinates streaming credentials and re-authentication for a connection pool.
//
// The manager is responsible for:
// - Creating and managing per-connection credentials listeners
// - Providing the pool hook for re-authentication
// - Coordinating between credentials updates and pool operations
//
// When credentials change via a StreamingCredentialsProvider:
// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update
// 2. It calls MarkForReAuth on the manager
// 3. The manager delegates to the pool hook
// 4. The pool hook schedules background re-authentication
//
// The manager maintains a registry of credentials listeners indexed by connection ID,
// allowing listener reuse when connections are reinitialized (e.g., after handoff).
type Manager struct {
// credentialsListeners maps connection ID to credentials listener
credentialsListeners *CredentialsListeners
// pool is the connection pool being managed
pool pool.Pooler
// poolHookRef is the re-authentication pool hook
poolHookRef *ReAuthPoolHook
}
// NewManager creates a new streaming credentials manager.
//
// Parameters:
// - pl: The connection pool to manage
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
m := &Manager{
pool: pl,
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
credentialsListeners: NewCredentialsListeners(),
}
m.poolHookRef.manager = m
return m
}
// PoolHook returns the pool hook for re-authentication.
//
// This hook should be registered with the connection pool to enable
// automatic re-authentication when credentials change.
func (m *Manager) PoolHook() pool.PoolHook {
return m.poolHookRef
}
// Listener returns or creates a credentials listener for a connection.
//
// This method is called during connection initialization to set up the
// credentials listener. If a listener already exists for the connection ID
// (e.g., after a handoff), it is reused.
//
// Parameters:
// - poolCn: The connection to create/get a listener for
// - reAuth: Function to re-authenticate the connection with new credentials
// - onErr: Function to call when re-authentication fails
//
// Returns:
// - auth.CredentialsListener: The listener to subscribe to the credentials provider
// - error: Non-nil if poolCn is nil
//
// Note: The reAuth and onErr callbacks are captured once when the listener is
// created and reused for the connection's lifetime. They should not change.
//
// Thread-safe: Can be called concurrently during connection initialization.
func (m *Manager) Listener(
poolCn *pool.Conn,
reAuth func(*pool.Conn, auth.Credentials) error,
onErr func(*pool.Conn, error),
) (auth.CredentialsListener, error) {
if poolCn == nil {
return nil, errors.New("poolCn cannot be nil")
}
connID := poolCn.GetID()
// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
// so we can get the old listener from the cache and use it.
// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
listener, ok := m.credentialsListeners.Get(connID)
if !ok || listener == nil {
// Create new listener for this connection
// Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime
newCredListener := &ConnReAuthCredentialsListener{
conn: poolCn,
reAuth: reAuth,
onErr: onErr,
manager: m,
}
m.credentialsListeners.Add(connID, newCredListener)
listener = newCredListener
}
return listener, nil
}
// MarkForReAuth marks a connection for re-authentication.
//
// This method is called by the credentials listener when new credentials are
// received. It delegates to the pool hook to schedule background re-authentication.
//
// Parameters:
// - poolCn: The connection to re-authenticate
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Called by credentials listeners when credentials change.
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
connID := poolCn.GetID()
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
}
// RemoveListener removes the credentials listener for a connection.
//
// This method is called by the pool hook's OnRemove to clean up listeners
// when connections are removed from the pool.
//
// Parameters:
// - connID: The connection ID whose listener should be removed
//
// Thread-safe: Called during connection removal.
func (m *Manager) RemoveListener(connID uint64) {
m.credentialsListeners.Remove(connID)
}
@@ -0,0 +1,241 @@
package streaming
import (
"context"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
)
// ReAuthPoolHook is a pool hook that manages background re-authentication of connections
// when credentials change via a streaming credentials provider.
//
// The hook uses a semaphore-based worker pool to limit concurrent re-authentication
// operations and prevent pool exhaustion. When credentials change, connections are
// marked for re-authentication and processed asynchronously in the background.
//
// The re-authentication process:
// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth
// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth)
// 3. A worker goroutine acquires the connection (waits until it's not in use)
// 4. Executes the re-auth function while holding the connection
// 5. Releases the connection back to the pool
//
// The hook ensures that:
// - Only one re-auth operation runs per connection at a time
// - Connections are not used for commands during re-authentication
// - Re-auth operations timeout if they can't acquire the connection
// - Resources are properly cleaned up on connection removal
type ReAuthPoolHook struct {
// shouldReAuth maps connection ID to re-auth function
// Connections in this map need re-authentication but haven't been scheduled yet
shouldReAuth map[uint64]func(error)
shouldReAuthLock sync.RWMutex
// workers is a semaphore limiting concurrent re-auth operations
// Initialized with poolSize tokens to prevent pool exhaustion
// Uses FastSemaphore for better performance with eventual fairness
workers *internal.FastSemaphore
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
reAuthTimeout time.Duration
// scheduledReAuth maps connection ID to scheduled status
// Connections in this map have a background worker attempting re-authentication
scheduledReAuth map[uint64]bool
scheduledLock sync.RWMutex
// manager is a back-reference for cleanup operations
manager *Manager
}
// NewReAuthPoolHook creates a new re-authentication pool hook.
//
// Parameters:
// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size)
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
//
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
// re-auth operations don't exhaust the connection pool.
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
return &ReAuthPoolHook{
shouldReAuth: make(map[uint64]func(error)),
scheduledReAuth: make(map[uint64]bool),
workers: internal.NewFastSemaphore(int32(poolSize)),
reAuthTimeout: reAuthTimeout,
}
}
// MarkForReAuth marks a connection for re-authentication.
//
// This method is called when credentials change and a connection needs to be
// re-authenticated. The actual re-authentication happens asynchronously when
// the connection is returned to the pool (in OnPut).
//
// Parameters:
// - connID: The connection ID to mark for re-authentication
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
//
// Thread-safe: Can be called concurrently from multiple goroutines.
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
r.shouldReAuthLock.Lock()
defer r.shouldReAuthLock.Unlock()
r.shouldReAuth[connID] = reAuthFn
}
// OnGet is called when a connection is retrieved from the pool.
//
// This hook checks if the connection needs re-authentication or has a scheduled
// re-auth operation. If so, it rejects the connection (returns accept=false),
// causing the pool to try another connection.
//
// Returns:
// - accept: false if connection needs re-auth, true otherwise
// - err: always nil (errors are not used in this hook)
//
// Thread-safe: Called concurrently by multiple goroutines getting connections.
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
connID := conn.GetID()
r.shouldReAuthLock.RLock()
_, shouldReAuth := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock()
// This connection was marked for reauth while in the pool,
// reject the connection
if shouldReAuth {
// simply reject the connection, it will be re-authenticated in OnPut
return false, nil
}
r.scheduledLock.RLock()
_, hasScheduled := r.scheduledReAuth[connID]
r.scheduledLock.RUnlock()
// has scheduled reauth, reject the connection
if hasScheduled {
// simply reject the connection, it currently has a reauth scheduled
// and the worker is waiting for slot to execute the reauth
return false, nil
}
return true, nil
}
// OnPut is called when a connection is returned to the pool.
//
// This hook checks if the connection needs re-authentication. If so, it schedules
// a background goroutine to perform the re-auth asynchronously. The goroutine:
// 1. Waits for a worker slot (semaphore)
// 2. Acquires the connection (waits until not in use)
// 3. Executes the re-auth function
// 4. Releases the connection and worker slot
//
// The connection is always pooled (not removed) since re-auth happens in background.
//
// Returns:
// - shouldPool: always true (connection stays in pool during background re-auth)
// - shouldRemove: always false
// - err: always nil
//
// Thread-safe: Called concurrently by multiple goroutines returning connections.
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
if conn == nil {
// noop
return true, false, nil
}
connID := conn.GetID()
// Check if reauth is needed and get the function with proper locking
r.shouldReAuthLock.RLock()
reAuthFn, ok := r.shouldReAuth[connID]
r.shouldReAuthLock.RUnlock()
if ok {
// Acquire both locks to atomically move from shouldReAuth to scheduledReAuth
// This prevents race conditions where OnGet might miss the transition
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock()
r.scheduledReAuth[connID] = true
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock()
r.shouldReAuthLock.Unlock()
go func() {
r.workers.AcquireBlocking()
// safety first
if conn == nil || (conn != nil && conn.IsClosed()) {
r.workers.Release()
return
}
defer func() {
if rec := recover(); rec != nil {
// once again - safety first
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
}
r.scheduledLock.Lock()
delete(r.scheduledReAuth, connID)
r.scheduledLock.Unlock()
r.workers.Release()
}()
// Create timeout context for connection acquisition
// This prevents indefinite waiting if the connection is stuck
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
defer cancel()
// Try to acquire the connection for re-authentication
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
// This prevents re-authentication from interfering with active commands
// Use AwaitAndTransition to wait for the connection to become IDLE
stateMachine := conn.GetStateMachine()
if stateMachine == nil {
// No state machine - should not happen, but handle gracefully
reAuthFn(pool.ErrConnUnusableTimeout)
return
}
// Use predefined slice to avoid allocation
_, err := stateMachine.AwaitAndTransition(ctx, pool.ValidFromIdle(), pool.StateUnusable)
if err != nil {
// Timeout or other error occurred, cannot acquire connection
reAuthFn(err)
return
}
// safety first
if !conn.IsClosed() {
// Successfully acquired the connection, perform reauth
reAuthFn(nil)
}
// Release the connection: transition from UNUSABLE back to IDLE
stateMachine.Transition(pool.StateIdle)
}()
}
// the reauth will happen in background, as far as the pool is concerned:
// pool the connection, don't remove it, no error
return true, false, nil
}
// OnRemove is called when a connection is removed from the pool.
//
// This hook cleans up all state associated with the connection:
// - Removes from shouldReAuth map (pending re-auth)
// - Removes from scheduledReAuth map (active re-auth)
// - Removes credentials listener from manager
//
// This prevents memory leaks and ensures that removed connections don't have
// lingering re-auth operations or listeners.
//
// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure.
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
connID := conn.GetID()
r.shouldReAuthLock.Lock()
r.scheduledLock.Lock()
delete(r.scheduledReAuth, connID)
delete(r.shouldReAuth, connID)
r.scheduledLock.Unlock()
r.shouldReAuthLock.Unlock()
if r.manager != nil {
r.manager.RemoveListener(connID)
}
}
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
+54
View File
@@ -0,0 +1,54 @@
// Package interfaces provides shared interfaces used by both the main redis package
// and the maintnotifications upgrade package to avoid circular dependencies.
package interfaces
import (
"context"
"net"
"time"
)
// NotificationProcessor is (most probably) a push.NotificationProcessor
// forward declaration to avoid circular imports
type NotificationProcessor interface {
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
UnregisterHandler(pushNotificationName string) error
GetHandler(pushNotificationName string) interface{}
}
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
type ClientInterface interface {
// GetOptions returns the client options.
GetOptions() OptionsInterface
// GetPushProcessor returns the client's push notification processor.
GetPushProcessor() NotificationProcessor
}
// OptionsInterface defines the interface for client options.
// Uses an adapter pattern to avoid circular dependencies.
type OptionsInterface interface {
// GetReadTimeout returns the read timeout.
GetReadTimeout() time.Duration
// GetWriteTimeout returns the write timeout.
GetWriteTimeout() time.Duration
// GetNetwork returns the network type.
GetNetwork() string
// GetAddr returns the connection address.
GetAddr() string
// IsTLSEnabled returns true if TLS is enabled.
IsTLSEnabled() bool
// GetProtocol returns the protocol version.
GetProtocol() int
// GetPoolSize returns the connection pool size.
GetPoolSize() int
// NewDialer returns a new dialer function for the connection.
NewDialer() func(context.Context) (net.Conn, error)
}
+57 -4
View File
@@ -7,20 +7,73 @@ import (
"os"
)
// TODO (ned): Revisit logging
// Add more standardized approach with log levels and configurability
type Logging interface {
Printf(ctx context.Context, format string, v ...interface{})
}
type logger struct {
type DefaultLogger struct {
log *log.Logger
}
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
_ = l.log.Output(2, fmt.Sprintf(format, v...))
}
func NewDefaultLogger() Logging {
return &DefaultLogger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
var Logger Logging = NewDefaultLogger()
var LogLevel LogLevelT = LogLevelError
// LogLevelT represents the logging level
type LogLevelT int
// Log level constants for the entire go-redis library
const (
LogLevelError LogLevelT = iota // 0 - errors only
LogLevelWarn // 1 - warnings and errors
LogLevelInfo // 2 - info, warnings, and errors
LogLevelDebug // 3 - debug, info, warnings, and errors
)
// String returns the string representation of the log level
func (l LogLevelT) String() string {
switch l {
case LogLevelError:
return "ERROR"
case LogLevelWarn:
return "WARN"
case LogLevelInfo:
return "INFO"
case LogLevelDebug:
return "DEBUG"
default:
return "UNKNOWN"
}
}
// IsValid returns true if the log level is valid
func (l LogLevelT) IsValid() bool {
return l >= LogLevelError && l <= LogLevelDebug
}
func (l LogLevelT) WarnOrAbove() bool {
return l >= LogLevelWarn
}
func (l LogLevelT) InfoOrAbove() bool {
return l >= LogLevelInfo
}
func (l LogLevelT) DebugOrAbove() bool {
return l >= LogLevelDebug
}
@@ -0,0 +1,625 @@
package logs
import (
"encoding/json"
"fmt"
"regexp"
"github.com/redis/go-redis/v9/internal"
)
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
func appendJSONIfDebug(message string, data map[string]interface{}) string {
if internal.LogLevel.DebugOrAbove() {
jsonData, _ := json.Marshal(data)
return fmt.Sprintf("%s %s", message, string(jsonData))
}
return message
}
const (
// ========================================
// CIRCUIT_BREAKER.GO - Circuit breaker management
// ========================================
CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
CircuitBreakerOpenedMessage = "circuit breaker opened"
CircuitBreakerReopenedMessage = "circuit breaker reopened"
CircuitBreakerClosedMessage = "circuit breaker closed"
CircuitBreakerCleanupMessage = "circuit breaker cleanup"
CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"
// ========================================
// CONFIG.GO - Configuration and debug
// ========================================
DebugLoggingEnabledMessage = "debug logging enabled"
ConfigDebugMessage = "config debug"
// ========================================
// ERRORS.GO - Error message constants
// ========================================
InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
InvalidEndpointTypeErrorMessage = "invalid endpoint type"
InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
InvalidClientErrorMessage = "invalid client type"
InvalidNotificationErrorMessage = "invalid notification format"
MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
ShutdownErrorMessage = "shutdown"
CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"
// ========================================
// EXAMPLE_HOOKS.GO - Example metrics hooks
// ========================================
MetricsHookProcessingNotificationMessage = "metrics hook processing"
MetricsHookRecordedErrorMessage = "metrics hook recorded error"
// ========================================
// HANDOFF_WORKER.GO - Connection handoff processing
// ========================================
HandoffStartedMessage = "handoff started"
HandoffFailedMessage = "handoff failed"
ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
HandoffRetryAttemptMessage = "Performing handoff"
CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
HandoffQueueFullMessage = "handoff queue is full"
FailedToDialNewEndpointMessage = "failed to dial new endpoint"
ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
HandoffSuccessMessage = "handoff succeeded"
RemovingConnectionFromPoolMessage = "removing connection from pool"
NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
WorkerPanicRecoveredMessage = "worker panic recovered"
WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
ReachedMaxHandoffRetriesMessage = "reached max handoff retries"
// ========================================
// MANAGER.GO - Moving operation tracking and handler registration
// ========================================
DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
TrackingMovingOperationMessage = "tracking MOVING operation"
UntrackingMovingOperationMessage = "untracking MOVING operation"
OperationNotTrackedMessage = "operation not tracked"
FailedToRegisterHandlerMessage = "failed to register handler"
// ========================================
// HOOKS.GO - Notification processing hooks
// ========================================
ProcessingNotificationMessage = "processing notification started"
ProcessingNotificationFailedMessage = "proccessing notification failed"
ProcessingNotificationSucceededMessage = "processing notification succeeded"
// ========================================
// POOL_HOOK.GO - Pool connection management
// ========================================
FailedToQueueHandoffMessage = "failed to queue handoff"
MarkedForHandoffMessage = "connection marked for handoff"
// ========================================
// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
// ========================================
InvalidNotificationFormatMessage = "invalid notification format"
InvalidNotificationTypeFormatMessage = "invalid notification type format"
InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
NoConnectionInHandlerContextMessage = "no connection in handler context"
InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
UnrelaxedTimeoutMessage = "clearing relaxed timeout"
ManagerNotInitializedMessage = "manager not initialized"
FailedToMarkForHandoffMessage = "failed to mark connection for handoff"
// ========================================
// used in pool/conn
// ========================================
UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
)
func HandoffStarted(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
"attempt": attempt,
"maxAttempts": maxAttempts,
"error": err.Error(),
})
}
func HandoffSucceeded(connID uint64, newEndpoint string) string {
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": newEndpoint,
})
}
// Timeout-related log functions
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"timeout": fmt.Sprintf("%v", timeout),
})
}
func UnrelaxedTimeout(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Handoff queue and marking functions
func HandoffQueueFull(queueLen, queueCap int) string {
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
return appendJSONIfDebug(message, map[string]interface{}{
"queueLen": queueLen,
"queueCap": queueCap,
})
}
func FailedToQueueHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToMarkForHandoff(connID uint64, err error) string {
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"error": err.Error(),
})
}
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"error": err.Error(),
})
}
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"maxRetries": maxRetries,
})
}
// Notification processing functions
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
"error": err.Error(),
"notification": fmt.Sprintf("%v", notification),
})
}
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"notificationType": notificationType,
})
}
// Moving operation tracking functions
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
"seqID": seqID,
})
}
func UntrackingMovingOperation(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
func OperationNotTracked(connID uint64, seqID int64) string {
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seqID": seqID,
})
}
// Connection pool functions
func RemovingConnectionFromPool(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"reason": reason.Error(),
})
}
// Circuit breaker functions
func CircuitBreakerOpen(connID uint64, endpoint string) string {
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"endpoint": endpoint,
})
}
// Additional handoff functions for specific cases
func ConnectionNotMarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
func ConnectionNotMarkedForHandoffError(connID uint64) string {
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
}
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"retries": retries,
"newEndpoint": newEndpoint,
"oldEndpoint": oldEndpoint,
})
}
func CannotQueueHandoffForRetry(err error) string {
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
return appendJSONIfDebug(message, map[string]interface{}{
"error": err.Error(),
})
}
// Validation and error functions
func InvalidNotificationFormat(notification interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidNotificationTypeFormat(notificationType interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": fmt.Sprintf("%v", notificationType),
})
}
// InvalidNotification creates a log message for invalid notifications of any type
func InvalidNotification(notificationType string, notification interface{}) string {
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"notification": fmt.Sprintf("%v", notification),
})
}
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
return appendJSONIfDebug(message, map[string]interface{}{
"seqID": fmt.Sprintf("%v", seqID),
})
}
func InvalidTimeSInMovingNotification(timeS interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
return appendJSONIfDebug(message, map[string]interface{}{
"timeS": fmt.Sprintf("%v", timeS),
})
}
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
})
}
func NoConnectionInHandlerContext(notificationType string) string {
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
})
}
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connType": fmt.Sprintf("%T", conn),
})
}
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"seconds": seconds,
})
}
func ManagerNotInitialized() string {
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
}
func FailedToRegisterHandler(notificationType string, err error) string {
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"error": err.Error(),
})
}
func ShutdownError() string {
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
}
// Configuration validation error functions
func InvalidRelaxedTimeoutError() string {
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffTimeoutError() string {
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidHandoffWorkersError() string {
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
}
func InvalidHandoffQueueSizeError() string {
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
}
func InvalidPostHandoffRelaxedDurationError() string {
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
}
func InvalidEndpointTypeError() string {
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
}
func InvalidMaintNotificationsError() string {
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
}
func InvalidHandoffRetriesError() string {
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
}
func InvalidClientError() string {
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
}
func InvalidNotificationError() string {
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
}
func MaxHandoffRetriesReachedError() string {
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
}
func HandoffQueueFullError() string {
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerFailureThresholdError() string {
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerResetTimeoutError() string {
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
}
func InvalidCircuitBreakerMaxRequestsError() string {
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
}
// Configuration and debug functions
func DebugLoggingEnabled() string {
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
}
func ConfigDebug(config interface{}) string {
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
return appendJSONIfDebug(message, map[string]interface{}{
"config": fmt.Sprintf("%+v", config),
})
}
// Handoff worker functions
func WorkerExitingDueToShutdown() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
}
func WorkerExitingDueToShutdownWhileProcessing() string {
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
}
func WorkerPanicRecovered(panicValue interface{}) string {
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
return appendJSONIfDebug(message, map[string]interface{}{
"panic": fmt.Sprintf("%v", panicValue),
})
}
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
return appendJSONIfDebug(message, map[string]interface{}{
"timeout": fmt.Sprintf("%v", timeout),
})
}
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
"timeout": fmt.Sprintf("%v", timeout),
"until": until,
})
}
// Example hooks functions
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
})
}
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
return appendJSONIfDebug(message, map[string]interface{}{
"notificationType": notificationType,
"connID": connID,
"error": err.Error(),
})
}
// Pool hook functions
func MarkedForHandoff(connID uint64) string {
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
return appendJSONIfDebug(message, map[string]interface{}{
"connID": connID,
})
}
// Circuit breaker additional functions
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerOpened(endpoint string, failures int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"failures": failures,
})
}
func CircuitBreakerReopened(endpoint string) string {
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
})
}
func CircuitBreakerClosed(endpoint string, successes int64) string {
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
return appendJSONIfDebug(message, map[string]interface{}{
"endpoint": endpoint,
"successes": successes,
})
}
func CircuitBreakerCleanup(removed int, total int) string {
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
return appendJSONIfDebug(message, map[string]interface{}{
"removed": removed,
"total": total,
})
}
// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages
// Returns a map containing the parsed key-value pairs from the structured data section
// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}"
// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"}
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
result := make(map[string]interface{})
// Find the JSON data section at the end of the message
re := regexp.MustCompile(`(\{.*\})$`)
matches := re.FindStringSubmatch(logMessage)
if len(matches) < 2 {
return result
}
jsonStr := matches[1]
if jsonStr == "" {
return result
}
// Parse the JSON directly
var jsonResult map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil {
return jsonResult
}
// If JSON parsing fails, return empty map
return result
}
+789 -21
View File
@@ -1,28 +1,124 @@
// Package pool implements the pool management
package pool
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/proto"
)
var noDeadline = time.Time{}
// Preallocated errors for hot paths to avoid allocations
var (
errAlreadyMarkedForHandoff = errors.New("connection is already marked for handoff")
errNotMarkedForHandoff = errors.New("connection was not marked for handoff")
errHandoffStateChanged = errors.New("handoff state changed during marking")
errConnectionNotAvailable = errors.New("redis: connection not available")
errConnNotAvailableForWrite = errors.New("redis: connection not available for write operation")
)
// getCachedTimeNs returns the current time in nanoseconds.
// This function previously used a global cache updated by a background goroutine,
// but that caused unnecessary CPU usage when the client was idle (ticker waking up
// the scheduler every 50ms). We now use time.Now() directly, which is fast enough
// on modern systems (vDSO on Linux) and only adds ~1-2% overhead in extreme
// high-concurrency benchmarks while eliminating idle CPU usage.
func getCachedTimeNs() int64 {
return time.Now().UnixNano()
}
// GetCachedTimeNs returns the current time in nanoseconds.
// Exported for use by other packages that need fast time access.
func GetCachedTimeNs() int64 {
return getCachedTimeNs()
}
// Global atomic counter for connection IDs
var connIDCounter uint64
// HandoffState represents the atomic state for connection handoffs
// This struct is stored atomically to prevent race conditions between
// checking handoff status and reading handoff parameters
type HandoffState struct {
ShouldHandoff bool // Whether connection should be handed off
Endpoint string // New endpoint for handoff
SeqID int64 // Sequence ID from MOVING notification
}
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
type atomicNetConn struct {
conn net.Conn
}
// generateConnID generates a fast unique identifier for a connection with zero allocations
func generateConnID() uint64 {
return atomic.AddUint64(&connIDCounter, 1)
}
type Conn struct {
usedAt int64 // atomic
netConn net.Conn
// Connection identifier for unique tracking
id uint64
usedAt atomic.Int64
lastPutAt atomic.Int64
// Lock-free netConn access using atomic.Value
// Contains *atomicNetConn wrapper, accessed atomically for better performance
netConnAtomic atomic.Value // stores *atomicNetConn
rd *proto.Reader
bw *bufio.Writer
wr *proto.Writer
Inited bool
// Lightweight mutex to protect reader operations during handoff
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
readerMu sync.RWMutex
// State machine for connection state management
// Replaces: usable, Inited, used
// Provides thread-safe state transitions with FIFO waiting queue
// States: CREATED → INITIALIZING → IDLE ⇄ IN_USE
// ↓
// UNUSABLE (handoff/reauth)
// ↓
// IDLE/CLOSED
stateMachine *ConnStateMachine
// Handoff metadata - managed separately from state machine
// These are atomic for lock-free access during handoff operations
handoffStateAtomic atomic.Value // stores *HandoffState
handoffRetriesAtomic atomic.Uint32 // retry counter
pooled bool
pubsub bool
closed atomic.Bool
createdAt time.Time
expiresAt time.Time
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
// Using atomic operations for lock-free access to avoid mutex contention
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
// Counter to track multiple relaxed timeout setters if we have nested calls
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
// if counter reaches 0, we clear the relaxed timeouts
relaxedCounter atomic.Int32
// Connection initialization function for reconnections
initConnFunc func(context.Context, *Conn) error
onClose func() error
}
@@ -32,9 +128,11 @@ func NewConn(netConn net.Conn) *Conn {
}
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
now := time.Now()
cn := &Conn{
netConn: netConn,
createdAt: time.Now(),
createdAt: now,
id: generateConnID(), // Generate unique ID for this connection
stateMachine: NewConnStateMachine(),
}
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
@@ -50,37 +148,656 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
}
// Store netConn atomically for lock-free access using wrapper
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
cn.wr = proto.NewWriter(cn.bw)
cn.SetUsedAt(time.Now())
cn.SetUsedAt(now)
// Initialize handoff state atomically
initialHandoffState := &HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
}
cn.handoffStateAtomic.Store(initialHandoffState)
return cn
}
func (cn *Conn) UsedAt() time.Time {
unix := atomic.LoadInt64(&cn.usedAt)
return time.Unix(unix, 0)
return time.Unix(0, cn.usedAt.Load())
}
func (cn *Conn) SetUsedAt(tm time.Time) {
cn.usedAt.Store(tm.UnixNano())
}
func (cn *Conn) SetUsedAt(tm time.Time) {
atomic.StoreInt64(&cn.usedAt, tm.Unix())
func (cn *Conn) UsedAtNs() int64 {
return cn.usedAt.Load()
}
func (cn *Conn) SetUsedAtNs(ns int64) {
cn.usedAt.Store(ns)
}
func (cn *Conn) LastPutAtNs() int64 {
return cn.lastPutAt.Load()
}
func (cn *Conn) SetLastPutAtNs(ns int64) {
cn.lastPutAt.Store(ns)
}
// Backward-compatible wrapper methods for state machine
// These maintain the existing API while using the new state machine internally
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
//
// This is used by background operations (handoff, re-auth) to acquire exclusive
// access to a connection. The operation sets usable to false, preventing the pool
// from returning the connection to clients.
//
// Returns true if the swap was successful (old value matched), false otherwise.
//
// Implementation note: This is a compatibility wrapper around the state machine.
// It checks if the current state is "usable" (IDLE or IN_USE) and transitions accordingly.
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
currentState := cn.stateMachine.GetState()
// Check if current state matches the "old" usable value
currentUsable := (currentState == StateIdle || currentState == StateInUse)
if currentUsable != old {
return false
}
// If we're trying to set to the same value, succeed immediately
if old == new {
return true
}
// Transition based on new value
if new {
// Trying to make usable - transition from UNUSABLE to IDLE
// This should only work from UNUSABLE or INITIALIZING states
// Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(
validFromInitializingOrUnusable,
StateIdle,
)
return err == nil
}
// Trying to make unusable - transition from IDLE to UNUSABLE
// This is typically for acquiring the connection for background operations
// Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(
validFromIdle,
StateUnusable,
)
return err == nil
}
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
//
// A connection is "usable" when it's in a stable state and can be returned to clients.
// It becomes unusable during:
// - Handoff operations (network connection replacement)
// - Re-authentication (credential updates)
// - Other background operations that need exclusive access
//
// Note: CREATED state is considered usable because new connections need to pass OnGet() hook
// before initialization. The initialization happens after OnGet() in the client code.
func (cn *Conn) IsUsable() bool {
state := cn.stateMachine.GetState()
// CREATED, IDLE, and IN_USE states are considered usable
// CREATED: new connection, not yet initialized (will be initialized by client)
// IDLE: initialized and ready to be acquired
// IN_USE: usable but currently acquired by someone
return state == StateCreated || state == StateIdle || state == StateInUse
}
// SetUsable sets the usable flag for the connection (lock-free).
//
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
// This method is kept for backwards compatibility.
//
// This should be called to mark a connection as usable after initialization or
// to release it after a background operation completes.
//
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
func (cn *Conn) SetUsable(usable bool) {
if usable {
// Transition to IDLE state (ready to be acquired)
cn.stateMachine.Transition(StateIdle)
} else {
// Transition to UNUSABLE state (for background operations)
cn.stateMachine.Transition(StateUnusable)
}
}
// IsInited returns true if the connection has been initialized.
// This is a backward-compatible wrapper around the state machine.
func (cn *Conn) IsInited() bool {
state := cn.stateMachine.GetState()
// Connection is initialized if it's in IDLE or any post-initialization state
return state != StateCreated && state != StateInitializing && state != StateClosed
}
// Used - State machine based implementation
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
// This method is kept for backwards compatibility.
//
// This is the preferred method for acquiring a connection from the pool, as it
// ensures that only one goroutine marks the connection as used.
//
// Implementation: Uses state machine transitions IDLE ⇄ IN_USE
//
// Returns true if the swap was successful (old value matched), false otherwise.
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
if old == new {
// No change needed
currentState := cn.stateMachine.GetState()
currentUsed := (currentState == StateInUse)
return currentUsed == old
}
if !old && new {
// Acquiring: IDLE → IN_USE
// Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
return err == nil
} else {
// Releasing: IN_USE → IDLE
// Use predefined slice to avoid allocation
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
return err == nil
}
}
// IsUsed returns true if the connection is currently in use (lock-free).
//
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
// This method is kept for backwards compatibility.
//
// A connection is "used" when it has been retrieved from the pool and is
// actively processing a command. Background operations (like re-auth) should
// wait until the connection is not used before executing commands.
func (cn *Conn) IsUsed() bool {
return cn.stateMachine.GetState() == StateInUse
}
// SetUsed sets the used flag for the connection (lock-free).
//
// This should be called when returning a connection to the pool (set to false)
// or when a single-connection pool retrieves its connection (set to true).
//
// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
// avoid race conditions.
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
func (cn *Conn) SetUsed(val bool) {
if val {
cn.stateMachine.Transition(StateInUse)
} else {
cn.stateMachine.Transition(StateIdle)
}
}
// getNetConn returns the current network connection using atomic load (lock-free).
// This is the fast path for accessing netConn without mutex overhead.
func (cn *Conn) getNetConn() net.Conn {
if v := cn.netConnAtomic.Load(); v != nil {
if wrapper, ok := v.(*atomicNetConn); ok {
return wrapper.conn
}
}
return nil
}
// setNetConn stores the network connection atomically (lock-free).
// This is used for the fast path of connection replacement.
func (cn *Conn) setNetConn(netConn net.Conn) {
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
}
// Handoff state management - atomic access to handoff metadata
// ShouldHandoff returns true if connection needs handoff (lock-free).
func (cn *Conn) ShouldHandoff() bool {
if v := cn.handoffStateAtomic.Load(); v != nil {
return v.(*HandoffState).ShouldHandoff
}
return false
}
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
func (cn *Conn) GetHandoffEndpoint() string {
if v := cn.handoffStateAtomic.Load(); v != nil {
return v.(*HandoffState).Endpoint
}
return ""
}
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
func (cn *Conn) GetMovingSeqID() int64 {
if v := cn.handoffStateAtomic.Load(); v != nil {
return v.(*HandoffState).SeqID
}
return 0
}
// GetHandoffInfo returns all handoff information atomically (lock-free).
// This method prevents race conditions by returning all handoff state in a single atomic operation.
// Returns (shouldHandoff, endpoint, seqID).
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
if v := cn.handoffStateAtomic.Load(); v != nil {
state := v.(*HandoffState)
return state.ShouldHandoff, state.Endpoint, state.SeqID
}
return false, "", 0
}
// HandoffRetries returns the current handoff retry count (lock-free).
func (cn *Conn) HandoffRetries() int {
return int(cn.handoffRetriesAtomic.Load())
}
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
return int(cn.handoffRetriesAtomic.Add(uint32(n)))
}
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
func (cn *Conn) IsPooled() bool {
return cn.pooled
}
// IsPubSub returns true if the connection is used for PubSub.
func (cn *Conn) IsPubSub() bool {
return cn.pubsub
}
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
// These timeouts will be used for all subsequent commands until the deadline expires.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
cn.relaxedCounter.Add(1)
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
}
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
// After the deadline, timeouts automatically revert to normal values.
// Uses atomic operations for lock-free access.
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
}
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
// Uses atomic operations for lock-free access.
func (cn *Conn) ClearRelaxedTimeout() {
// Atomically decrement counter and check if we should clear
newCount := cn.relaxedCounter.Add(-1)
deadlineNs := cn.relaxedDeadlineNs.Load()
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
// Use atomic load to get current value for CAS to avoid stale value race
current := cn.relaxedCounter.Load()
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
cn.clearRelaxedTimeout()
}
}
}
func (cn *Conn) clearRelaxedTimeout() {
cn.relaxedReadTimeoutNs.Store(0)
cn.relaxedWriteTimeoutNs.Store(0)
cn.relaxedDeadlineNs.Store(0)
cn.relaxedCounter.Store(0)
}
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
// This checks both the timeout values and the deadline (if set).
// Uses atomic operations for lock-free access.
func (cn *Conn) HasRelaxedTimeout() bool {
// Fast path: no relaxed timeouts are set
if cn.relaxedCounter.Load() <= 0 {
return false
}
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// If no relaxed timeouts are set, return false
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
return false
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, relaxed timeouts are active
if deadlineNs == 0 {
return true
}
// If deadline is set, check if it's still in the future
return time.Now().UnixNano() < deadlineNs
}
// getEffectiveReadTimeout returns the timeout to use for read operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
// Fast path: no relaxed timeout set
if readTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(readTimeoutNs)
}
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
nowNs := getCachedTimeNs()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(readTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
// getEffectiveWriteTimeout returns the timeout to use for write operations.
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
// This method automatically clears expired relaxed timeouts using atomic operations.
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
// Fast path: no relaxed timeout set
if writeTimeoutNs <= 0 {
return normalTimeout
}
deadlineNs := cn.relaxedDeadlineNs.Load()
// If no deadline is set, use relaxed timeout
if deadlineNs == 0 {
return time.Duration(writeTimeoutNs)
}
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
nowNs := getCachedTimeNs()
// Check if deadline has passed
if nowNs < deadlineNs {
// Deadline is in the future, use relaxed timeout
return time.Duration(writeTimeoutNs)
} else {
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
newCount := cn.relaxedCounter.Add(-1)
if newCount <= 0 {
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
cn.clearRelaxedTimeout()
}
return normalTimeout
}
}
func (cn *Conn) SetOnClose(fn func() error) {
cn.onClose = fn
}
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
cn.initConnFunc = fn
}
// ExecuteInitConn runs the stored connection initialization function if available.
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
if cn.initConnFunc != nil {
return cn.initConnFunc(ctx, cn)
}
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
}
func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.netConn = netConn
// Store the new connection atomically first (lock-free)
cn.setNetConn(netConn)
// Protect reader reset operations to avoid data races
// Use write lock since we're modifying the reader state
cn.readerMu.Lock()
cn.rd.Reset(netConn)
cn.readerMu.Unlock()
cn.bw.Reset(netConn)
}
// GetNetConn safely returns the current network connection using atomic load (lock-free).
// This method is used by the pool for health checks and provides better performance.
func (cn *Conn) GetNetConn() net.Conn {
return cn.getNetConn()
}
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
// This method ensures only one initialization can happen at a time by using atomic state transitions.
// If another goroutine is currently initializing, this will wait for it to complete.
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
// If another goroutine is initializing, we'll wait for it to finish
// if the context has a deadline, use that, otherwise use the connection read (relaxed) timeout
// which should be set during handoff. If it is not set, use a 5 second default
deadline, ok := ctx.Deadline()
if !ok {
deadline = time.Now().Add(cn.getEffectiveReadTimeout(5 * time.Second))
}
waitCtx, cancel := context.WithDeadline(ctx, deadline)
defer cancel()
// Use predefined slice to avoid allocation
finalState, err := cn.stateMachine.AwaitAndTransition(
waitCtx,
validFromCreatedIdleOrUnusable,
StateInitializing,
)
if err != nil {
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
}
// Replace the underlying connection
cn.SetNetConn(netConn)
// Execute initialization
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
// or CLOSED on failure. We don't need to do it here.
// NOTE: Initconn returns conn in IDLE state
initErr := cn.ExecuteInitConn(ctx)
if initErr != nil {
// ExecuteInitConn already transitioned to CLOSED, just return the error
return initErr
}
// ExecuteInitConn already transitioned to IDLE
return nil
}
// MarkForHandoff marks the connection for handoff due to MOVING notification.
// Returns an error if the connection is already marked for handoff.
// Note: This only sets metadata - the connection state is not changed until OnPut.
// This allows the current user to finish using the connection before handoff.
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
// Check if already marked for handoff
if cn.ShouldHandoff() {
return errAlreadyMarkedForHandoff
}
// Set handoff metadata atomically
cn.handoffStateAtomic.Store(&HandoffState{
ShouldHandoff: true,
Endpoint: newEndpoint,
SeqID: seqID,
})
return nil
}
// MarkQueuedForHandoff marks the connection as queued for handoff processing.
// This makes the connection unusable until handoff completes.
// This is called from OnPut hook, where the connection is typically in IN_USE state.
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
func (cn *Conn) MarkQueuedForHandoff() error {
// Get current handoff state
currentState := cn.handoffStateAtomic.Load()
if currentState == nil {
return errNotMarkedForHandoff
}
state := currentState.(*HandoffState)
if !state.ShouldHandoff {
return errNotMarkedForHandoff
}
// Create new state with ShouldHandoff=false but preserve endpoint and seqID
// This prevents the connection from being queued multiple times while still
// allowing the worker to access the handoff metadata
newState := &HandoffState{
ShouldHandoff: false,
Endpoint: state.Endpoint, // Preserve endpoint for handoff processing
SeqID: state.SeqID, // Preserve seqID for handoff processing
}
// Atomic compare-and-swap to update state
if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
// State changed between load and CAS - retry or return error
return errHandoffStateChanged
}
// Transition to UNUSABLE from IN_USE (normal flow), IDLE (edge cases), or CREATED (tests/uninitialized)
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
// But in some edge cases or tests, it might be in IDLE or CREATED state
// The pool will detect this state change and preserve it (not overwrite with IDLE)
// Use predefined slice to avoid allocation
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
if err != nil {
// Check if already in UNUSABLE state (race condition or retry)
// ShouldHandoff should be false now, but check just in case
if finalState == StateUnusable && !cn.ShouldHandoff() {
// Already unusable - this is fine, keep the new handoff state
return nil
}
// Restore the original state if transition fails for other reasons
cn.handoffStateAtomic.Store(currentState)
return fmt.Errorf("failed to mark connection as unusable: %w", err)
}
return nil
}
// GetID returns the unique identifier for this connection.
func (cn *Conn) GetID() uint64 {
return cn.id
}
// GetStateMachine returns the connection's state machine for advanced state management.
// This is primarily used by internal packages like maintnotifications for handoff processing.
func (cn *Conn) GetStateMachine() *ConnStateMachine {
return cn.stateMachine
}
// TryAcquire attempts to acquire the connection for use.
// This is an optimized inline method for the hot path (Get operation).
//
// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED.
// Returns true if the connection was successfully acquired, false otherwise.
// The CREATED->CREATED is done so we can keep the state correct for later
// initialization of the connection in initConn.
//
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast()
//
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
// methods. This breaks encapsulation but is necessary for performance.
// The IDLE->IN_USE and CREATED->CREATED transitions don't need
// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever
// needs to notify waiters on these transitions, update this to use TryTransitionFast().
func (cn *Conn) TryAcquire() bool {
// The || operator short-circuits, so only 1 CAS in the common case
return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) ||
cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated))
}
// Release releases the connection back to the pool.
// This is an optimized inline method for the hot path (Put operation).
//
// It tries to transition from IN_USE -> IDLE.
// Returns true if the connection was successfully released, false otherwise.
//
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast().
//
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
// methods. This breaks encapsulation but is necessary for performance.
// If the state machine ever needs to notify waiters
// on this transition, update this to use TryTransitionFast().
func (cn *Conn) Release() bool {
// Inline the hot path - single CAS operation
return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle))
}
// ClearHandoffState clears the handoff state after successful handoff.
// Makes the connection usable again.
func (cn *Conn) ClearHandoffState() {
// Clear handoff metadata
cn.handoffStateAtomic.Store(&HandoffState{
ShouldHandoff: false,
Endpoint: "",
SeqID: 0,
})
// Reset retry counter
cn.handoffRetriesAtomic.Store(0)
// Mark connection as usable again
// Use state machine directly instead of deprecated SetUsable
// probably done by initConn
cn.stateMachine.Transition(StateIdle)
}
// HasBufferedData safely checks if the connection has buffered data.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) HasBufferedData() bool {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
return cn.rd.Buffered() > 0
}
// PeekReplyTypeSafe safely peeks at the reply type.
// This method is used to avoid data races when checking for push notifications.
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
// Use read lock for concurrent access to reader state
cn.readerMu.RLock()
defer cn.readerMu.RUnlock()
if cn.rd.Buffered() <= 0 {
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
}
return cn.rd.PeekReplyType()
}
func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Write(b)
}
return 0, net.ErrClosed
}
func (cn *Conn) RemoteAddr() net.Addr {
if cn.netConn != nil {
return cn.netConn.RemoteAddr()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.RemoteAddr()
}
return nil
}
@@ -89,7 +806,16 @@ func (cn *Conn) WithReader(
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
// Get the connection directly from atomic storage
netConn := cn.getNetConn()
if netConn == nil {
return errConnectionNotAvailable
}
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
}
@@ -100,13 +826,25 @@ func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
if timeout >= 0 {
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
return err
// Use relaxed timeout if set, otherwise use provided timeout
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
// Set write deadline on the connection
if netConn := cn.getNetConn(); netConn != nil {
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
return err
}
} else {
// Connection is not available - return preallocated error
return errConnNotAvailableForWrite
}
}
// Reset the buffered writer if needed, should not happen
if cn.bw.Buffered() > 0 {
cn.bw.Reset(cn.netConn)
if netConn := cn.getNetConn(); netConn != nil {
cn.bw.Reset(netConn)
}
}
if err := fn(cn.wr); err != nil {
@@ -116,17 +854,47 @@ func (cn *Conn) WithWriter(
return cn.bw.Flush()
}
func (cn *Conn) IsClosed() bool {
return cn.closed.Load() || cn.stateMachine.GetState() == StateClosed
}
func (cn *Conn) Close() error {
cn.closed.Store(true)
// Transition to CLOSED state
cn.stateMachine.Transition(StateClosed)
if cn.onClose != nil {
// ignore error
_ = cn.onClose()
}
return cn.netConn.Close()
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return netConn.Close()
}
return nil
}
// MaybeHasData tries to peek at the next byte in the socket without consuming it
// This is used to check if there are push notifications available
// Important: This will work on Linux, but not on Windows
func (cn *Conn) MaybeHasData() bool {
// Lock-free netConn access for better performance
if netConn := cn.getNetConn(); netConn != nil {
return maybeHasData(netConn)
}
return false
}
// deadline computes the effective deadline time based on context and timeout.
// It updates the usedAt timestamp to now.
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
tm := time.Now()
cn.SetUsedAt(tm)
// Use cached time for deadline calculation (called 2x per command: read + write)
nowNs := getCachedTimeNs()
cn.SetUsedAtNs(nowNs)
tm := time.Unix(0, nowNs)
if timeout > 0 {
tm = tm.Add(timeout)
+11 -1
View File
@@ -12,6 +12,9 @@ import (
var errUnexpectedRead = errors.New("unexpected read from socket")
// connCheck checks if the connection is still alive and if there is data in the socket
// it will try to peek at the next byte without consuming it since we may want to work with it
// later on (e.g. push notifications)
func connCheck(conn net.Conn) error {
// Reset previous timeout.
_ = conn.SetDeadline(time.Time{})
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
if err := rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
// Use MSG_PEEK to peek at data without consuming it
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
switch {
case n == 0 && err == nil:
sysErr = io.EOF
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
return sysErr
}
// maybeHasData checks if there is data in the socket without consuming it
func maybeHasData(conn net.Conn) bool {
return connCheck(conn) == errUnexpectedRead
}
+13 -2
View File
@@ -2,8 +2,19 @@
package pool
import "net"
import (
"errors"
"net"
)
func connCheck(conn net.Conn) error {
// errUnexpectedRead is placeholder error variable for non-unix build constraints
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(_ net.Conn) error {
return nil
}
// since we can't check for data on the socket, we just assume there is some
func maybeHasData(_ net.Conn) bool {
return true
}
+343
View File
@@ -0,0 +1,343 @@
package pool
import (
"container/list"
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
)
// ConnState represents the connection state in the state machine.
// States are designed to be lightweight and fast to check.
//
// State Transitions:
//
// CREATED → INITIALIZING → IDLE ⇄ IN_USE
// ↓
// UNUSABLE (handoff/reauth)
// ↓
// IDLE/CLOSED
type ConnState uint32
const (
// StateCreated - Connection just created, not yet initialized
StateCreated ConnState = iota
// StateInitializing - Connection initialization in progress
StateInitializing
// StateIdle - Connection initialized and idle in pool, ready to be acquired
StateIdle
// StateInUse - Connection actively processing a command (retrieved from pool)
StateInUse
// StateUnusable - Connection temporarily unusable due to background operation
// (handoff, reauth, etc.). Cannot be acquired from pool.
StateUnusable
// StateClosed - Connection closed
StateClosed
)
// Predefined state slices to avoid allocations in hot paths
var (
validFromInUse = []ConnState{StateInUse}
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
// For AwaitAndTransition calls
validFromCreatedIdleOrUnusable = []ConnState{StateCreated, StateIdle, StateUnusable}
validFromIdle = []ConnState{StateIdle}
// For CompareAndSwapUsable
validFromInitializingOrUnusable = []ConnState{StateInitializing, StateUnusable}
)
// Accessor functions for predefined slices to avoid allocations in external packages
// These return the same slice instance, so they're zero-allocation
// ValidFromIdle returns a predefined slice containing only StateIdle.
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
func ValidFromIdle() []ConnState {
return validFromIdle
}
// ValidFromCreatedIdleOrUnusable returns a predefined slice for initialization transitions.
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
func ValidFromCreatedIdleOrUnusable() []ConnState {
return validFromCreatedIdleOrUnusable
}
// String returns a human-readable string representation of the state.
func (s ConnState) String() string {
switch s {
case StateCreated:
return "CREATED"
case StateInitializing:
return "INITIALIZING"
case StateIdle:
return "IDLE"
case StateInUse:
return "IN_USE"
case StateUnusable:
return "UNUSABLE"
case StateClosed:
return "CLOSED"
default:
return fmt.Sprintf("UNKNOWN(%d)", s)
}
}
var (
// ErrInvalidStateTransition is returned when a state transition is not allowed
ErrInvalidStateTransition = errors.New("invalid state transition")
// ErrStateMachineClosed is returned when operating on a closed state machine
ErrStateMachineClosed = errors.New("state machine is closed")
// ErrTimeout is returned when a state transition times out
ErrTimeout = errors.New("state transition timeout")
)
// waiter represents a goroutine waiting for a state transition.
// Designed for minimal allocations and fast processing.
type waiter struct {
validStates map[ConnState]struct{} // States we're waiting for
targetState ConnState // State to transition to
done chan error // Signaled when transition completes or times out
}
// ConnStateMachine manages connection state transitions with FIFO waiting queue.
// Optimized for:
// - Lock-free reads (hot path)
// - Minimal allocations
// - Fast state transitions
// - FIFO fairness for waiters
// Note: Handoff metadata (endpoint, seqID, retries) is managed separately in the Conn struct.
type ConnStateMachine struct {
// Current state - atomic for lock-free reads
state atomic.Uint32
// FIFO queue for waiters - only locked during waiter add/remove/notify
mu sync.Mutex
waiters *list.List // List of *waiter
waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path)
}
// NewConnStateMachine creates a new connection state machine.
// Initial state is StateCreated.
func NewConnStateMachine() *ConnStateMachine {
sm := &ConnStateMachine{
waiters: list.New(),
}
sm.state.Store(uint32(StateCreated))
return sm
}
// GetState returns the current state (lock-free read).
// This is the hot path - optimized for zero allocations and minimal overhead.
// Note: Zero allocations applies to state reads; converting the returned state to a string
// (via String()) may allocate if the state is unknown.
func (sm *ConnStateMachine) GetState() ConnState {
return ConnState(sm.state.Load())
}
// TryTransitionFast is an optimized version for the hot path (Get/Put operations).
// It only handles simple state transitions without waiter notification.
// This is safe because:
// 1. Get/Put don't need to wait for state changes
// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match
// 3. If a background operation is in progress (state is UNUSABLE), this fails fast
//
// Returns true if transition succeeded, false otherwise.
// Use this for performance-critical paths where you don't need error details.
//
// Performance: Single CAS operation - as fast as the old atomic bool!
// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target)
// The || operator short-circuits, so only 1 CAS is executed in the common case.
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState))
}
// TryTransition attempts an immediate state transition without waiting.
// Returns the current state after the transition attempt and an error if the transition failed.
// The returned state is the CURRENT state (after the attempt), not the previous state.
// This is faster than AwaitAndTransition when you don't need to wait.
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
// This method does NOT wait - it fails immediately if the transition cannot be performed.
//
// Performance: Zero allocations on success path (hot path).
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
// Try each valid from state with CAS
// This ensures only ONE goroutine can successfully transition at a time
for _, fromState := range validFromStates {
// Try to atomically swap from fromState to targetState
// If successful, we won the race and can proceed
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
// Success! We transitioned atomically
// Hot path optimization: only check for waiters if transition succeeded
// This avoids atomic load on every Get/Put when no waiters exist
if sm.waiterCount.Load() > 0 {
sm.notifyWaiters()
}
return targetState, nil
}
}
// All CAS attempts failed - state is not valid for this transition
// Return the current state so caller can decide what to do
// Note: This error path allocates, but it's the exceptional case
currentState := sm.GetState()
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
ErrInvalidStateTransition, currentState, targetState, validFromStates)
}
// Transition unconditionally transitions to the target state.
// Use with caution - prefer AwaitAndTransition or TryTransition for safety.
// This is useful for error paths or when you know the transition is valid.
func (sm *ConnStateMachine) Transition(targetState ConnState) {
sm.state.Store(uint32(targetState))
sm.notifyWaiters()
}
// AwaitAndTransition waits for the connection to reach one of the valid states,
// then atomically transitions to the target state.
// Returns the current state after the transition attempt and an error if the operation failed.
// The returned state is the CURRENT state (after the attempt), not the previous state.
// Returns error if timeout expires or context is cancelled.
//
// This method implements FIFO fairness - the first caller to wait gets priority
// when the state becomes available.
//
// Performance notes:
// - If already in a valid state, this is very fast (no allocation, no waiting)
// - If waiting is required, allocates one waiter struct and one channel
func (sm *ConnStateMachine) AwaitAndTransition(
ctx context.Context,
validFromStates []ConnState,
targetState ConnState,
) (ConnState, error) {
// Fast path: try immediate transition with CAS to prevent race conditions
// BUT: only if there are no waiters in the queue (to maintain FIFO ordering)
if sm.waiterCount.Load() == 0 {
for _, fromState := range validFromStates {
// Check if we're already in target state
if fromState == targetState && sm.GetState() == targetState {
return targetState, nil
}
// Try to atomically swap from fromState to targetState
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
// Success! We transitioned atomically
sm.notifyWaiters()
return targetState, nil
}
}
}
// Fast path failed - check if we should wait or fail
currentState := sm.GetState()
// Check if closed
if currentState == StateClosed {
return currentState, ErrStateMachineClosed
}
// Slow path: need to wait for state change
// Create waiter with valid states map for fast lookup
validStatesMap := make(map[ConnState]struct{}, len(validFromStates))
for _, s := range validFromStates {
validStatesMap[s] = struct{}{}
}
w := &waiter{
validStates: validStatesMap,
targetState: targetState,
done: make(chan error, 1), // Buffered to avoid goroutine leak
}
// Add to FIFO queue
sm.mu.Lock()
elem := sm.waiters.PushBack(w)
sm.waiterCount.Add(1)
sm.mu.Unlock()
// Wait for state change or timeout
select {
case <-ctx.Done():
// Timeout or cancellation - remove from queue
sm.mu.Lock()
sm.waiters.Remove(elem)
sm.waiterCount.Add(-1)
sm.mu.Unlock()
return sm.GetState(), ctx.Err()
case err := <-w.done:
// Transition completed (or failed)
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
// or here (on timeout/cancellation).
return sm.GetState(), err
}
}
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
// This is called after every state transition.
func (sm *ConnStateMachine) notifyWaiters() {
// Fast path: check atomic counter without acquiring lock
// This eliminates mutex overhead in the common case (no waiters)
if sm.waiterCount.Load() == 0 {
return
}
sm.mu.Lock()
defer sm.mu.Unlock()
// Double-check after acquiring lock (waiters might have been processed)
if sm.waiters.Len() == 0 {
return
}
// Process waiters in FIFO order until no more can be processed
// We loop instead of recursing to avoid stack overflow and mutex issues
for {
processed := false
// Find the first waiter that can proceed
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
w := elem.Value.(*waiter)
// Read current state inside the loop to get the latest value
currentState := sm.GetState()
// Check if current state is valid for this waiter
if _, valid := w.validStates[currentState]; valid {
// Remove from queue first
sm.waiters.Remove(elem)
sm.waiterCount.Add(-1)
// Use CAS to ensure state hasn't changed since we checked
// This prevents race condition where another thread changes state
// between our check and our transition
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
// Successfully transitioned - notify waiter
w.done <- nil
processed = true
break
} else {
// State changed - re-add waiter to front of queue to maintain FIFO ordering
// This waiter was first in line and should retain priority
sm.waiters.PushFront(w)
sm.waiterCount.Add(1)
// Continue to next iteration to re-read state
processed = true
break
}
}
}
// If we didn't process any waiter, we're done
if !processed {
break
}
}
}
+165
View File
@@ -0,0 +1,165 @@
package pool
import (
"context"
"sync"
)
// PoolHook defines the interface for connection lifecycle hooks.
type PoolHook interface {
// OnGet is called when a connection is retrieved from the pool.
// It can modify the connection or return an error to prevent its use.
// The accept flag can be used to prevent the connection from being used.
// On Accept = false the connection is rejected and returned to the pool.
// The error can be used to prevent the connection from being used and returned to the pool.
// On Errors, the connection is removed from the pool.
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
// The flag can be used for gathering metrics on pool hit/miss ratio.
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
// OnPut is called when a connection is returned to the pool.
// It returns whether the connection should be pooled and whether it should be removed.
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
// OnRemove is called when a connection is removed from the pool.
// This happens when:
// - Connection fails health check
// - Connection exceeds max lifetime
// - Pool is being closed
// - Connection encounters an error
// Implementations should clean up any per-connection state.
// The reason parameter indicates why the connection was removed.
OnRemove(ctx context.Context, conn *Conn, reason error)
}
// PoolHookManager manages multiple pool hooks.
type PoolHookManager struct {
hooks []PoolHook
hooksMu sync.RWMutex
}
// NewPoolHookManager creates a new pool hook manager.
func NewPoolHookManager() *PoolHookManager {
return &PoolHookManager{
hooks: make([]PoolHook, 0),
}
}
// AddHook adds a pool hook to the manager.
// Hooks are called in the order they were added.
func (phm *PoolHookManager) AddHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
phm.hooks = append(phm.hooks, hook)
}
// RemoveHook removes a pool hook from the manager.
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
phm.hooksMu.Lock()
defer phm.hooksMu.Unlock()
for i, h := range phm.hooks {
if h == hook {
// Remove hook by swapping with last element and truncating
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
phm.hooks = phm.hooks[:len(phm.hooks)-1]
break
}
}
}
// ProcessOnGet calls all OnGet hooks in order.
// If any hook returns an error, processing stops and the error is returned.
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
// Copy slice reference while holding lock (fast)
phm.hooksMu.RLock()
hooks := phm.hooks
phm.hooksMu.RUnlock()
// Call hooks without holding lock (slow operations)
for _, hook := range hooks {
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
if err != nil {
return false, err
}
if !acceptConn {
return false, nil
}
}
return true, nil
}
// ProcessOnPut calls all OnPut hooks in order.
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
// Copy slice reference while holding lock (fast)
phm.hooksMu.RLock()
hooks := phm.hooks
phm.hooksMu.RUnlock()
shouldPool = true // Default to pooling the connection
// Call hooks without holding lock (slow operations)
for _, hook := range hooks {
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
if hookErr != nil {
return false, true, hookErr
}
// If any hook says to remove or not pool, respect that decision
if hookShouldRemove {
return false, true, nil
}
if !hookShouldPool {
shouldPool = false
}
}
return shouldPool, false, nil
}
// ProcessOnRemove calls all OnRemove hooks in order.
func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
// Copy slice reference while holding lock (fast)
phm.hooksMu.RLock()
hooks := phm.hooks
phm.hooksMu.RUnlock()
// Call hooks without holding lock (slow operations)
for _, hook := range hooks {
hook.OnRemove(ctx, conn, reason)
}
}
// GetHookCount returns the number of registered hooks (for testing).
func (phm *PoolHookManager) GetHookCount() int {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
return len(phm.hooks)
}
// GetHooks returns a copy of all registered hooks.
func (phm *PoolHookManager) GetHooks() []PoolHook {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
hooks := make([]PoolHook, len(phm.hooks))
copy(hooks, phm.hooks)
return hooks
}
// Clone creates a copy of the hook manager with the same hooks.
// This is used for lock-free atomic updates of the hook manager.
func (phm *PoolHookManager) Clone() *PoolHookManager {
phm.hooksMu.RLock()
defer phm.hooksMu.RUnlock()
newManager := &PoolHookManager{
hooks: make([]PoolHook, len(phm.hooks)),
}
copy(newManager.hooks, phm.hooks)
return newManager
}
File diff suppressed because it is too large Load Diff
+50 -4
View File
@@ -1,7 +1,13 @@
package pool
import "context"
import (
"context"
"time"
)
// SingleConnPool is a pool that always returns the same connection.
// Note: This pool is not thread-safe.
// It is intended to be used by clients that need a single connection.
type SingleConnPool struct {
pool Pooler
cn *Conn
@@ -10,6 +16,12 @@ type SingleConnPool struct {
var _ Pooler = (*SingleConnPool)(nil)
// NewSingleConnPool creates a new single connection pool.
// The pool will always return the same connection.
// The pool will not:
// - Close the connection
// - Reconnect the connection
// - Track the connection in any way
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
return &SingleConnPool{
pool: pool,
@@ -25,20 +37,47 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
if p.stickyErr != nil {
return nil, p.stickyErr
}
if p.cn == nil {
return nil, ErrClosed
}
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
// - During initialization (connection is in INITIALIZING state)
// - During re-authentication (connection is in UNUSABLE state)
// - For transactions (connection might be in various states)
// We use SetUsed() which forces the transition, rather than TryTransition() which
// would fail if the connection is not in IDLE/CREATED state.
p.cn.SetUsed(true)
p.cn.SetUsedAt(time.Now())
return p.cn, nil
}
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
if p.cn == nil {
return
}
if p.cn != cn {
return
}
p.cn.SetUsed(false)
}
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
cn.SetUsed(false)
p.cn = nil
p.stickyErr = reason
}
// RemoveWithoutTurn has the same behavior as Remove for SingleConnPool
// since SingleConnPool doesn't use a turn-based queue system.
func (p *SingleConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
p.Remove(ctx, cn, reason)
}
func (p *SingleConnPool) Close() error {
p.cn = nil
p.stickyErr = ErrClosed
@@ -53,6 +92,13 @@ func (p *SingleConnPool) IdleLen() int {
return 0
}
// Size returns the maximum pool size, which is always 1 for SingleConnPool.
func (p *SingleConnPool) Size() int { return 1 }
func (p *SingleConnPool) Stats() *Stats {
return &Stats{}
}
func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}
func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}
+13
View File
@@ -123,6 +123,12 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
p.ch <- cn
}
// RemoveWithoutTurn has the same behavior as Remove for StickyConnPool
// since StickyConnPool doesn't use a turn-based queue system.
func (p *StickyConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
p.Remove(ctx, cn, reason)
}
func (p *StickyConnPool) Close() error {
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
return nil
@@ -196,6 +202,13 @@ func (p *StickyConnPool) IdleLen() int {
return len(p.ch)
}
// Size returns the maximum pool size, which is always 1 for StickyConnPool.
func (p *StickyConnPool) Size() int { return 1 }
func (p *StickyConnPool) Stats() *Stats {
return &Stats{}
}
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
+80
View File
@@ -0,0 +1,80 @@
package pool
import (
"context"
"net"
"sync"
"sync/atomic"
)
type PubSubStats struct {
Created uint32
Untracked uint32
Active uint32
}
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
opt *Options
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// Map to track active PubSub connections
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
closed atomic.Bool
stats PubSubStats
}
// NewPubSubPool implements a pool for PubSub connections.
// It intentionally does not implement the Pooler interface
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
return &PubSubPool{
opt: opt,
netDialer: netDialer,
}
}
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
if p.closed.Load() {
return nil, ErrClosed
}
netConn, err := p.netDialer(ctx, network, addr)
if err != nil {
return nil, err
}
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
cn.pubsub = true
atomic.AddUint32(&p.stats.Created, 1)
return cn, nil
}
func (p *PubSubPool) TrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, 1)
p.activeConns.Store(cn.GetID(), cn)
}
func (p *PubSubPool) UntrackConn(cn *Conn) {
atomic.AddUint32(&p.stats.Active, ^uint32(0))
atomic.AddUint32(&p.stats.Untracked, 1)
p.activeConns.Delete(cn.GetID())
}
func (p *PubSubPool) Close() error {
p.closed.Store(true)
p.activeConns.Range(func(key, value interface{}) bool {
cn := value.(*Conn)
_ = cn.Close()
return true
})
return nil
}
func (p *PubSubPool) Stats() *PubSubStats {
// load stats atomically
return &PubSubStats{
Created: atomic.LoadUint32(&p.stats.Created),
Untracked: atomic.LoadUint32(&p.stats.Untracked),
Active: atomic.LoadUint32(&p.stats.Active),
}
}
+93
View File
@@ -0,0 +1,93 @@
package pool
import (
"context"
"sync"
)
type wantConn struct {
mu sync.Mutex // protects ctx, done and sending of the result
ctx context.Context // context for dial, cleared after delivered or canceled
cancelCtx context.CancelFunc
done bool // true after delivered or canceled
result chan wantConnResult // channel to deliver connection or error
}
// getCtxForDial returns context for dial or nil if connection was delivered or canceled.
func (w *wantConn) getCtxForDial() context.Context {
w.mu.Lock()
defer w.mu.Unlock()
return w.ctx
}
func (w *wantConn) tryDeliver(cn *Conn, err error) bool {
w.mu.Lock()
defer w.mu.Unlock()
if w.done {
return false
}
w.done = true
w.ctx = nil
w.result <- wantConnResult{cn: cn, err: err}
close(w.result)
return true
}
func (w *wantConn) cancel() *Conn {
w.mu.Lock()
var cn *Conn
if w.done {
select {
case result := <-w.result:
cn = result.cn
default:
}
} else {
close(w.result)
}
w.done = true
w.ctx = nil
w.mu.Unlock()
return cn
}
type wantConnResult struct {
cn *Conn
err error
}
type wantConnQueue struct {
mu sync.RWMutex
items []*wantConn
}
func newWantConnQueue() *wantConnQueue {
return &wantConnQueue{
items: make([]*wantConn, 0),
}
}
func (q *wantConnQueue) enqueue(w *wantConn) {
q.mu.Lock()
defer q.mu.Unlock()
q.items = append(q.items, w)
}
func (q *wantConnQueue) dequeue() (*wantConn, bool) {
q.mu.Lock()
defer q.mu.Unlock()
if len(q.items) == 0 {
return nil, false
}
item := q.items[0]
q.items = q.items[1:]
return item, true
}
+89 -2
View File
@@ -50,7 +50,8 @@ func (e RedisError) Error() string { return string(e) }
func (RedisError) RedisError() {}
func ParseErrorReply(line []byte) error {
return RedisError(line[1:])
msg := string(line[1:])
return parseTypedRedisError(msg)
}
//------------------------------------------------------------------------------
@@ -99,6 +100,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
return b[0], nil
}
func (r *Reader) PeekPushNotificationName() (string, error) {
// "prime" the buffer by peeking at the next byte
c, err := r.Peek(1)
if err != nil {
return "", err
}
if c[0] != RespPush {
return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
}
// peek 36 bytes at most, should be enough to read the push notification name
toPeek := 36
buffered := r.Buffered()
if buffered == 0 {
return "", fmt.Errorf("redis: can't peek push notification name, no data available")
}
if buffered < toPeek {
toPeek = buffered
}
buf, err := r.rd.Peek(toPeek)
if err != nil {
return "", err
}
if buf[0] != RespPush {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
if len(buf) < 3 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// remove push notification type
buf = buf[1:]
// remove first line - e.g. >2\r\n
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
}
// next line should be $<length><string>\r\n or +<length><string>\r\n
// should have the type of the push notification name and it's length
if buf[0] != RespString && buf[0] != RespStatus {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
typeOfName := buf[0]
// remove the type of the push notification name
buf = buf[1:]
if typeOfName == RespString {
// remove the length of the string
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[i+2:]
break
} else {
if buf[i] < '0' || buf[i] > '9' {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
}
}
}
if len(buf) < 2 {
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
}
// keep only the notification name
for i := 0; i < len(buf)-1; i++ {
if buf[i] == '\r' && buf[i+1] == '\n' {
buf = buf[:i]
break
}
}
return util.BytesToString(buf), nil
}
// ReadLine Return a valid reply, it will check the protocol or redis error,
// and discard the attribute type.
func (r *Reader) ReadLine() ([]byte, error) {
@@ -115,7 +202,7 @@ func (r *Reader) ReadLine() ([]byte, error) {
var blobErr string
blobErr, err = r.readStringReply(line)
if err == nil {
err = RedisError(blobErr)
err = parseTypedRedisError(blobErr)
}
return nil, err
case RespAttr:
+488
View File
@@ -0,0 +1,488 @@
package proto
import (
"errors"
"strings"
)
// Typed Redis errors for better error handling with wrapping support.
// These errors maintain backward compatibility by keeping the same error messages.
// LoadingError is returned when Redis is loading the dataset in memory.
type LoadingError struct {
msg string
}
func (e *LoadingError) Error() string {
return e.msg
}
func (e *LoadingError) RedisError() {}
// NewLoadingError creates a new LoadingError with the given message.
func NewLoadingError(msg string) *LoadingError {
return &LoadingError{msg: msg}
}
// ReadOnlyError is returned when trying to write to a read-only replica.
type ReadOnlyError struct {
msg string
}
func (e *ReadOnlyError) Error() string {
return e.msg
}
func (e *ReadOnlyError) RedisError() {}
// NewReadOnlyError creates a new ReadOnlyError with the given message.
func NewReadOnlyError(msg string) *ReadOnlyError {
return &ReadOnlyError{msg: msg}
}
// MovedError is returned when a key has been moved to a different node in a cluster.
type MovedError struct {
msg string
addr string
}
func (e *MovedError) Error() string {
return e.msg
}
func (e *MovedError) RedisError() {}
// Addr returns the address of the node where the key has been moved.
func (e *MovedError) Addr() string {
return e.addr
}
// NewMovedError creates a new MovedError with the given message and address.
func NewMovedError(msg string, addr string) *MovedError {
return &MovedError{msg: msg, addr: addr}
}
// AskError is returned when a key is being migrated and the client should ask another node.
type AskError struct {
msg string
addr string
}
func (e *AskError) Error() string {
return e.msg
}
func (e *AskError) RedisError() {}
// Addr returns the address of the node to ask.
func (e *AskError) Addr() string {
return e.addr
}
// NewAskError creates a new AskError with the given message and address.
func NewAskError(msg string, addr string) *AskError {
return &AskError{msg: msg, addr: addr}
}
// ClusterDownError is returned when the cluster is down.
type ClusterDownError struct {
msg string
}
func (e *ClusterDownError) Error() string {
return e.msg
}
func (e *ClusterDownError) RedisError() {}
// NewClusterDownError creates a new ClusterDownError with the given message.
func NewClusterDownError(msg string) *ClusterDownError {
return &ClusterDownError{msg: msg}
}
// TryAgainError is returned when a command cannot be processed and should be retried.
type TryAgainError struct {
msg string
}
func (e *TryAgainError) Error() string {
return e.msg
}
func (e *TryAgainError) RedisError() {}
// NewTryAgainError creates a new TryAgainError with the given message.
func NewTryAgainError(msg string) *TryAgainError {
return &TryAgainError{msg: msg}
}
// MasterDownError is returned when the master is down.
type MasterDownError struct {
msg string
}
func (e *MasterDownError) Error() string {
return e.msg
}
func (e *MasterDownError) RedisError() {}
// NewMasterDownError creates a new MasterDownError with the given message.
func NewMasterDownError(msg string) *MasterDownError {
return &MasterDownError{msg: msg}
}
// MaxClientsError is returned when the maximum number of clients has been reached.
type MaxClientsError struct {
msg string
}
func (e *MaxClientsError) Error() string {
return e.msg
}
func (e *MaxClientsError) RedisError() {}
// NewMaxClientsError creates a new MaxClientsError with the given message.
func NewMaxClientsError(msg string) *MaxClientsError {
return &MaxClientsError{msg: msg}
}
// AuthError is returned when authentication fails.
type AuthError struct {
msg string
}
func (e *AuthError) Error() string {
return e.msg
}
func (e *AuthError) RedisError() {}
// NewAuthError creates a new AuthError with the given message.
func NewAuthError(msg string) *AuthError {
return &AuthError{msg: msg}
}
// PermissionError is returned when a user lacks required permissions.
type PermissionError struct {
msg string
}
func (e *PermissionError) Error() string {
return e.msg
}
func (e *PermissionError) RedisError() {}
// NewPermissionError creates a new PermissionError with the given message.
func NewPermissionError(msg string) *PermissionError {
return &PermissionError{msg: msg}
}
// ExecAbortError is returned when a transaction is aborted.
type ExecAbortError struct {
msg string
}
func (e *ExecAbortError) Error() string {
return e.msg
}
func (e *ExecAbortError) RedisError() {}
// NewExecAbortError creates a new ExecAbortError with the given message.
func NewExecAbortError(msg string) *ExecAbortError {
return &ExecAbortError{msg: msg}
}
// OOMError is returned when Redis is out of memory.
type OOMError struct {
msg string
}
func (e *OOMError) Error() string {
return e.msg
}
func (e *OOMError) RedisError() {}
// NewOOMError creates a new OOMError with the given message.
func NewOOMError(msg string) *OOMError {
return &OOMError{msg: msg}
}
// parseTypedRedisError parses a Redis error message and returns a typed error if applicable.
// This function maintains backward compatibility by keeping the same error messages.
func parseTypedRedisError(msg string) error {
// Check for specific error patterns and return typed errors
switch {
case strings.HasPrefix(msg, "LOADING "):
return NewLoadingError(msg)
case strings.HasPrefix(msg, "READONLY "):
return NewReadOnlyError(msg)
case strings.HasPrefix(msg, "MOVED "):
// Extract address from "MOVED <slot> <addr>"
addr := extractAddr(msg)
return NewMovedError(msg, addr)
case strings.HasPrefix(msg, "ASK "):
// Extract address from "ASK <slot> <addr>"
addr := extractAddr(msg)
return NewAskError(msg, addr)
case strings.HasPrefix(msg, "CLUSTERDOWN "):
return NewClusterDownError(msg)
case strings.HasPrefix(msg, "TRYAGAIN "):
return NewTryAgainError(msg)
case strings.HasPrefix(msg, "MASTERDOWN "):
return NewMasterDownError(msg)
case msg == "ERR max number of clients reached":
return NewMaxClientsError(msg)
case strings.HasPrefix(msg, "NOAUTH "), strings.HasPrefix(msg, "WRONGPASS "), strings.Contains(msg, "unauthenticated"):
return NewAuthError(msg)
case strings.HasPrefix(msg, "NOPERM "):
return NewPermissionError(msg)
case strings.HasPrefix(msg, "EXECABORT "):
return NewExecAbortError(msg)
case strings.HasPrefix(msg, "OOM "):
return NewOOMError(msg)
default:
// Return generic RedisError for unknown error types
return RedisError(msg)
}
}
// extractAddr extracts the address from MOVED/ASK error messages.
// Format: "MOVED <slot> <addr>" or "ASK <slot> <addr>"
func extractAddr(msg string) string {
ind := strings.LastIndex(msg, " ")
if ind == -1 {
return ""
}
return msg[ind+1:]
}
// IsLoadingError checks if an error is a LoadingError, even if wrapped.
func IsLoadingError(err error) bool {
if err == nil {
return false
}
var loadingErr *LoadingError
if errors.As(err, &loadingErr) {
return true
}
// Check if wrapped error is a RedisError with LOADING prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "LOADING ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "LOADING ")
}
// IsReadOnlyError checks if an error is a ReadOnlyError, even if wrapped.
func IsReadOnlyError(err error) bool {
if err == nil {
return false
}
var readOnlyErr *ReadOnlyError
if errors.As(err, &readOnlyErr) {
return true
}
// Check if wrapped error is a RedisError with READONLY prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "READONLY ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "READONLY ")
}
// IsMovedError checks if an error is a MovedError, even if wrapped.
// Returns the error and a boolean indicating if it's a MovedError.
func IsMovedError(err error) (*MovedError, bool) {
if err == nil {
return nil, false
}
var movedErr *MovedError
if errors.As(err, &movedErr) {
return movedErr, true
}
// Fallback to string checking for backward compatibility
s := err.Error()
if strings.HasPrefix(s, "MOVED ") {
// Parse: MOVED 3999 127.0.0.1:6381
parts := strings.Split(s, " ")
if len(parts) == 3 {
return &MovedError{msg: s, addr: parts[2]}, true
}
}
return nil, false
}
// IsAskError checks if an error is an AskError, even if wrapped.
// Returns the error and a boolean indicating if it's an AskError.
func IsAskError(err error) (*AskError, bool) {
if err == nil {
return nil, false
}
var askErr *AskError
if errors.As(err, &askErr) {
return askErr, true
}
// Fallback to string checking for backward compatibility
s := err.Error()
if strings.HasPrefix(s, "ASK ") {
// Parse: ASK 3999 127.0.0.1:6381
parts := strings.Split(s, " ")
if len(parts) == 3 {
return &AskError{msg: s, addr: parts[2]}, true
}
}
return nil, false
}
// IsClusterDownError checks if an error is a ClusterDownError, even if wrapped.
func IsClusterDownError(err error) bool {
if err == nil {
return false
}
var clusterDownErr *ClusterDownError
if errors.As(err, &clusterDownErr) {
return true
}
// Check if wrapped error is a RedisError with CLUSTERDOWN prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "CLUSTERDOWN ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "CLUSTERDOWN ")
}
// IsTryAgainError checks if an error is a TryAgainError, even if wrapped.
func IsTryAgainError(err error) bool {
if err == nil {
return false
}
var tryAgainErr *TryAgainError
if errors.As(err, &tryAgainErr) {
return true
}
// Check if wrapped error is a RedisError with TRYAGAIN prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "TRYAGAIN ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "TRYAGAIN ")
}
// IsMasterDownError checks if an error is a MasterDownError, even if wrapped.
func IsMasterDownError(err error) bool {
if err == nil {
return false
}
var masterDownErr *MasterDownError
if errors.As(err, &masterDownErr) {
return true
}
// Check if wrapped error is a RedisError with MASTERDOWN prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "MASTERDOWN ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "MASTERDOWN ")
}
// IsMaxClientsError checks if an error is a MaxClientsError, even if wrapped.
func IsMaxClientsError(err error) bool {
if err == nil {
return false
}
var maxClientsErr *MaxClientsError
if errors.As(err, &maxClientsErr) {
return true
}
// Check if wrapped error is a RedisError with max clients prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "ERR max number of clients reached") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "ERR max number of clients reached")
}
// IsAuthError checks if an error is an AuthError, even if wrapped.
func IsAuthError(err error) bool {
if err == nil {
return false
}
var authErr *AuthError
if errors.As(err, &authErr) {
return true
}
// Check if wrapped error is a RedisError with auth error prefix
var redisErr RedisError
if errors.As(err, &redisErr) {
s := redisErr.Error()
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
}
// Fallback to string checking for backward compatibility
s := err.Error()
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
}
// IsPermissionError checks if an error is a PermissionError, even if wrapped.
func IsPermissionError(err error) bool {
if err == nil {
return false
}
var permErr *PermissionError
if errors.As(err, &permErr) {
return true
}
// Check if wrapped error is a RedisError with NOPERM prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "NOPERM ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "NOPERM ")
}
// IsExecAbortError checks if an error is an ExecAbortError, even if wrapped.
func IsExecAbortError(err error) bool {
if err == nil {
return false
}
var execAbortErr *ExecAbortError
if errors.As(err, &execAbortErr) {
return true
}
// Check if wrapped error is a RedisError with EXECABORT prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "EXECABORT ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "EXECABORT ")
}
// IsOOMError checks if an error is an OOMError, even if wrapped.
func IsOOMError(err error) bool {
if err == nil {
return false
}
var oomErr *OOMError
if errors.As(err, &oomErr) {
return true
}
// Check if wrapped error is a RedisError with OOM prefix
var redisErr RedisError
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "OOM ") {
return true
}
// Fallback to string checking for backward compatibility
return strings.HasPrefix(err.Error(), "OOM ")
}
+3
View File
@@ -0,0 +1,3 @@
package internal
const RedisNull = "<nil>"
+193
View File
@@ -0,0 +1,193 @@
package internal
import (
"context"
"sync"
"time"
)
var semTimers = sync.Pool{
New: func() interface{} {
t := time.NewTimer(time.Hour)
t.Stop()
return t
},
}
// FastSemaphore is a channel-based semaphore optimized for performance.
// It uses a fast path that avoids timer allocation when tokens are available.
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
// Closing the semaphore unblocks all waiting goroutines.
//
// Performance: ~30 ns/op with zero allocations on fast path.
// Fairness: Eventual fairness (no starvation) but not strict FIFO.
type FastSemaphore struct {
tokens chan struct{}
max int32
}
// NewFastSemaphore creates a new fast semaphore with the given capacity.
func NewFastSemaphore(capacity int32) *FastSemaphore {
ch := make(chan struct{}, capacity)
// Pre-fill with tokens
for i := int32(0); i < capacity; i++ {
ch <- struct{}{}
}
return &FastSemaphore{
tokens: ch,
max: capacity,
}
}
// TryAcquire attempts to acquire a token without blocking.
// Returns true if successful, false if no tokens available.
func (s *FastSemaphore) TryAcquire() bool {
select {
case <-s.tokens:
return true
default:
return false
}
}
// Acquire acquires a token, blocking if necessary until one is available.
// Returns an error if the context is cancelled or the timeout expires.
// Uses a fast path to avoid timer allocation when tokens are immediately available.
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
// Check context first
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Try fast path first (no timer needed)
select {
case <-s.tokens:
return nil
default:
}
// Slow path: need to wait with timeout
timer := semTimers.Get().(*time.Timer)
defer semTimers.Put(timer)
timer.Reset(timeout)
select {
case <-s.tokens:
if !timer.Stop() {
<-timer.C
}
return nil
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ctx.Err()
case <-timer.C:
return timeoutErr
}
}
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
func (s *FastSemaphore) AcquireBlocking() {
<-s.tokens
}
// Release releases a token back to the semaphore.
func (s *FastSemaphore) Release() {
s.tokens <- struct{}{}
}
// Close closes the semaphore, unblocking all waiting goroutines.
// After close, all Acquire calls will receive a closed channel signal.
func (s *FastSemaphore) Close() {
close(s.tokens)
}
// Len returns the current number of acquired tokens.
func (s *FastSemaphore) Len() int32 {
return s.max - int32(len(s.tokens))
}
// FIFOSemaphore is a channel-based semaphore with strict FIFO ordering.
// Unlike FastSemaphore, this guarantees that threads are served in the exact order they call Acquire().
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
// Closing the semaphore unblocks all waiting goroutines.
//
// Performance: ~115 ns/op with zero allocations (slower than FastSemaphore due to timer allocation).
// Fairness: Strict FIFO ordering guaranteed by Go runtime.
type FIFOSemaphore struct {
tokens chan struct{}
max int32
}
// NewFIFOSemaphore creates a new FIFO semaphore with the given capacity.
func NewFIFOSemaphore(capacity int32) *FIFOSemaphore {
ch := make(chan struct{}, capacity)
// Pre-fill with tokens
for i := int32(0); i < capacity; i++ {
ch <- struct{}{}
}
return &FIFOSemaphore{
tokens: ch,
max: capacity,
}
}
// TryAcquire attempts to acquire a token without blocking.
// Returns true if successful, false if no tokens available.
func (s *FIFOSemaphore) TryAcquire() bool {
select {
case <-s.tokens:
return true
default:
return false
}
}
// Acquire acquires a token, blocking if necessary until one is available.
// Returns an error if the context is cancelled or the timeout expires.
// Always uses timer to guarantee FIFO ordering (no fast path).
func (s *FIFOSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
// No fast path - always use timer to guarantee FIFO
timer := semTimers.Get().(*time.Timer)
defer semTimers.Put(timer)
timer.Reset(timeout)
select {
case <-s.tokens:
if !timer.Stop() {
<-timer.C
}
return nil
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ctx.Err()
case <-timer.C:
return timeoutErr
}
}
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
func (s *FIFOSemaphore) AcquireBlocking() {
<-s.tokens
}
// Release releases a token back to the semaphore.
func (s *FIFOSemaphore) Release() {
s.tokens <- struct{}{}
}
// Close closes the semaphore, unblocking all waiting goroutines.
// After close, all Acquire calls will receive a closed channel signal.
func (s *FIFOSemaphore) Close() {
close(s.tokens)
}
// Len returns the current number of acquired tokens.
func (s *FIFOSemaphore) Len() int32 {
return s.max - int32(len(s.tokens))
}
+11
View File
@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
}
return f
}
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
if value > math.MaxInt32 {
return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
}
if value < math.MinInt32 {
return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
}
return int32(value), nil
}
+17
View File
@@ -0,0 +1,17 @@
package util
// Max returns the maximum of two integers
func Max(a, b int) int {
if a > b {
return a
}
return b
}
// Min returns the minimum of two integers
func Min(a, b int) int {
if a < b {
return a
}
return b
}
+218
View File
@@ -0,0 +1,218 @@
# Maintenance Notifications - FEATURES
## Overview
The Maintenance Notifications feature enables seamless Redis connection handoffs during cluster maintenance operations without dropping active connections. This feature leverages Redis RESP3 push notifications to provide zero-downtime maintenance for Redis Enterprise and compatible Redis deployments.
## Important
Using Maintenance Notifications may affect the read and write timeouts by relaxing them during maintenance operations.
This is necessary to prevent false failures due to increased latency during handoffs. The relaxed timeouts are automatically applied and removed as needed.
## Key Features
### Seamless Connection Handoffs
- **Zero-Downtime Maintenance**: Automatically handles connection transitions during cluster operations
- **Active Operation Preservation**: Transfers in-flight operations to new connections without interruption
- **Graceful Degradation**: Falls back to standard reconnection if handoff fails
### Push Notification Support
Supports all Redis Enterprise maintenance notification types:
- **MOVING** - Slot moving to a new node
- **MIGRATING** - Slot in migration state
- **MIGRATED** - Migration completed
- **FAILING_OVER** - Node failing over
- **FAILED_OVER** - Failover completed
### Circuit Breaker Pattern
- **Endpoint-Specific Failure Tracking**: Prevents repeated connection attempts to failing endpoints
- **Automatic Recovery Testing**: Half-open state allows gradual recovery validation
- **Configurable Thresholds**: Customize failure thresholds and reset timeouts
### Flexible Configuration
- **Auto-Detection Mode**: Automatically detects server support for maintenance notifications
- **Multiple Endpoint Types**: Support for internal/external IP/FQDN endpoint resolution
- **Auto-Scaling Workers**: Automatically sizes worker pool based on connection pool size
- **Timeout Management**: Separate timeouts for relaxed (during maintenance) and normal operations
### Extensible Hook System
- **Pre/Post Processing Hooks**: Monitor and customize notification handling
- **Built-in Hooks**: Logging and metrics collection hooks included
- **Custom Hook Support**: Implement custom business logic around maintenance events
### Comprehensive Monitoring
- **Metrics Collection**: Track notification counts, processing times, and error rates
- **Circuit Breaker Stats**: Monitor endpoint health and circuit breaker states
- **Operation Tracking**: Track active handoff operations and their lifecycle
## Architecture Highlights
### Event-Driven Handoff System
- **Asynchronous Processing**: Non-blocking handoff operations using worker pool pattern
- **Queue-Based Architecture**: Configurable queue size with auto-scaling support
- **Retry Mechanism**: Configurable retry attempts with exponential backoff
### Connection Pool Integration
- **Pool Hook Interface**: Seamless integration with go-redis connection pool
- **Connection State Management**: Atomic flags for connection usability tracking
- **Graceful Shutdown**: Ensures all in-flight handoffs complete before shutdown
### Thread-Safe Design
- **Lock-Free Operations**: Atomic operations for high-performance state tracking
- **Concurrent-Safe Maps**: sync.Map for tracking active operations
- **Minimal Lock Contention**: Read-write locks only where necessary
## Configuration Options
### Operation Modes
- **`ModeDisabled`**: Maintenance notifications completely disabled
- **`ModeEnabled`**: Forcefully enabled (fails if server doesn't support)
- **`ModeAuto`**: Auto-detect server support (recommended default)
### Endpoint Types
- **`EndpointTypeAuto`**: Auto-detect based on current connection
- **`EndpointTypeInternalIP`**: Use internal IP addresses
- **`EndpointTypeInternalFQDN`**: Use internal fully qualified domain names
- **`EndpointTypeExternalIP`**: Use external IP addresses
- **`EndpointTypeExternalFQDN`**: Use external fully qualified domain names
- **`EndpointTypeNone`**: No endpoint (reconnect with current configuration)
### Timeout Configuration
- **`RelaxedTimeout`**: Extended timeout during maintenance operations (default: 10s)
- **`HandoffTimeout`**: Maximum time for handoff completion (default: 15s)
- **`PostHandoffRelaxedDuration`**: Relaxed period after handoff (default: 2×RelaxedTimeout)
### Worker Pool Configuration
- **`MaxWorkers`**: Maximum concurrent handoff workers (auto-calculated if 0)
- **`HandoffQueueSize`**: Handoff queue capacity (auto-calculated if 0)
- **`MaxHandoffRetries`**: Maximum retry attempts for failed handoffs (default: 3)
### Circuit Breaker Configuration
- **`CircuitBreakerFailureThreshold`**: Failures before opening circuit (default: 5)
- **`CircuitBreakerResetTimeout`**: Time before testing recovery (default: 60s)
- **`CircuitBreakerMaxRequests`**: Max requests in half-open state (default: 3)
## Auto-Scaling Formulas
### Worker Pool Sizing
When `MaxWorkers = 0` (auto-calculate):
```
MaxWorkers = min(PoolSize/2, max(10, PoolSize/3))
```
### Queue Sizing
When `HandoffQueueSize = 0` (auto-calculate):
```
QueueSize = max(20 × MaxWorkers, PoolSize)
Capped by: min(MaxActiveConns + 1, 5 × PoolSize)
```
### Examples
- **Pool Size 100**: 33 workers, 660 queue (capped at 500)
- **Pool Size 100 + MaxActiveConns 150**: 33 workers, 151 queue
- **Pool Size 50**: 16 workers, 320 queue (capped at 250)
## Performance Characteristics
### Throughput
- **Non-Blocking Handoffs**: Client operations continue during handoffs
- **Concurrent Processing**: Multiple handoffs processed in parallel
- **Minimal Overhead**: Lock-free atomic operations for state tracking
### Latency
- **Relaxed Timeouts**: Extended timeouts during maintenance prevent false failures
- **Fast Path**: Connections not undergoing handoff have zero overhead
- **Graceful Degradation**: Failed handoffs fall back to standard reconnection
### Resource Usage
- **Memory Efficient**: Bounded queue sizes prevent memory exhaustion
- **Worker Pool**: Fixed worker count prevents goroutine explosion
- **Connection Reuse**: Handoff reuses existing connection objects
## Testing
### Unit Tests
- Comprehensive unit test coverage for all components
- Mock-based testing for isolation
- Concurrent operation testing
### Integration Tests
- Pool integration tests with real connection handoffs
- Circuit breaker behavior validation
- Hook system integration testing
### E2E Tests
- Real Redis Enterprise cluster testing
- Multiple scenario coverage (timeouts, endpoint types, stress tests)
- Fault injection testing
- TLS configuration testing
## Compatibility
### Requirements
- **Redis Protocol**: RESP3 required for push notifications
- **Redis Version**: Redis Enterprise or compatible Redis with maintenance notifications
- **Go Version**: Go 1.18+ (uses generics and atomic types)
### Client Support
#### Currently Supported
- **Standalone Client** (`redis.NewClient`)
#### Planned Support
- **Cluster Client** (not yet supported)
#### Will Not Support
- **Failover Client** (no planned support)
- **Ring Client** (no planned support)
## Migration Guide
### Enabling Maintenance Notifications
**Before:**
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 2, // RESP2
})
```
**After:**
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeAuto,
},
})
```
### Adding Monitoring
```go
// Get the manager from the client
manager := client.GetMaintNotificationsManager()
if manager != nil {
// Add logging hook
loggingHook := maintnotifications.NewLoggingHook(2) // Info level
manager.AddNotificationHook(loggingHook)
// Add metrics hook
metricsHook := maintnotifications.NewMetricsHook()
manager.AddNotificationHook(metricsHook)
}
```
## Known Limitations
1. **Standalone Only**: Currently only supported in standalone Redis clients
2. **RESP3 Required**: Push notifications require RESP3 protocol
3. **Server Support**: Requires Redis Enterprise or compatible Redis with maintenance notifications
4. **Single Connection Commands**: Some commands (MULTI/EXEC, WATCH) may need special handling
5. **No Failover/Ring Client Support**: Failover and Ring clients are not supported and there are no plans to add support
## Future Enhancements
- Cluster client support
- Enhanced metrics and observability
+67
View File
@@ -0,0 +1,67 @@
# Maintenance Notifications
Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.
## ⚠️ **Important Note**
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
## Quick Start
```go
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Protocol: 3, // RESP3 required
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeEnabled,
},
})
```
## Modes
- **`ModeDisabled`** - Maintenance notifications disabled
- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support)
- **`ModeAuto`** - Auto-detect server support (default)
## Configuration
```go
&maintnotifications.Config{
Mode: maintnotifications.ModeAuto,
EndpointType: maintnotifications.EndpointTypeAuto,
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxHandoffRetries: 3,
MaxWorkers: 0, // Auto-calculated
HandoffQueueSize: 0, // Auto-calculated
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
}
```
### Endpoint Types
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
- **`EndpointTypeInternalIP`** - Internal IP address
- **`EndpointTypeInternalFQDN`** - Internal FQDN
- **`EndpointTypeExternalIP`** - External IP address
- **`EndpointTypeExternalFQDN`** - External FQDN
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
### Auto-Scaling
**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`
**Examples:**
- Pool 100: 33 workers, 660 queue (capped at 500)
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
## How It Works
1. Redis sends push notifications about cluster maintenance operations
2. Client creates new connections to updated endpoints
3. Active operations transfer to new connections
4. Old connections close gracefully
## For more information, see [FEATURES](FEATURES.md)
@@ -0,0 +1,353 @@
package maintnotifications
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// CircuitBreakerState represents the state of a circuit breaker
type CircuitBreakerState int32
const (
// CircuitBreakerClosed - normal operation, requests allowed
CircuitBreakerClosed CircuitBreakerState = iota
// CircuitBreakerOpen - failing fast, requests rejected
CircuitBreakerOpen
// CircuitBreakerHalfOpen - testing if service recovered
CircuitBreakerHalfOpen
)
func (s CircuitBreakerState) String() string {
switch s {
case CircuitBreakerClosed:
return "closed"
case CircuitBreakerOpen:
return "open"
case CircuitBreakerHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
type CircuitBreaker struct {
// Configuration
failureThreshold int // Number of failures before opening
resetTimeout time.Duration // How long to stay open before testing
maxRequests int // Max requests allowed in half-open state
// State tracking (atomic for lock-free access)
state atomic.Int32 // CircuitBreakerState
failures atomic.Int64 // Current failure count
successes atomic.Int64 // Success count in half-open state
requests atomic.Int64 // Request count in half-open state
lastFailureTime atomic.Int64 // Unix timestamp of last failure
lastSuccessTime atomic.Int64 // Unix timestamp of last success
// Endpoint identification
endpoint string
config *Config
}
// newCircuitBreaker creates a new circuit breaker for an endpoint
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
// Use configuration values with sensible defaults
failureThreshold := 5
resetTimeout := 60 * time.Second
maxRequests := 3
if config != nil {
failureThreshold = config.CircuitBreakerFailureThreshold
resetTimeout = config.CircuitBreakerResetTimeout
maxRequests = config.CircuitBreakerMaxRequests
}
return &CircuitBreaker{
failureThreshold: failureThreshold,
resetTimeout: resetTimeout,
maxRequests: maxRequests,
endpoint: endpoint,
config: config,
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
}
}
// IsOpen returns true if the circuit breaker is open (rejecting requests)
func (cb *CircuitBreaker) IsOpen() bool {
state := CircuitBreakerState(cb.state.Load())
return state == CircuitBreakerOpen
}
// shouldAttemptReset checks if enough time has passed to attempt reset
func (cb *CircuitBreaker) shouldAttemptReset() bool {
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
return time.Since(lastFailure) >= cb.resetTimeout
}
// Execute runs the given function with circuit breaker protection
func (cb *CircuitBreaker) Execute(fn func() error) error {
// Single atomic state load for consistency
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerOpen:
if cb.shouldAttemptReset() {
// Attempt transition to half-open
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
cb.requests.Store(0)
cb.successes.Store(0)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
}
// Fall through to half-open logic
} else {
return ErrCircuitBreakerOpen
}
} else {
return ErrCircuitBreakerOpen
}
fallthrough
case CircuitBreakerHalfOpen:
requests := cb.requests.Add(1)
if requests > int64(cb.maxRequests) {
cb.requests.Add(-1) // Revert the increment
return ErrCircuitBreakerOpen
}
}
// Execute the function with consistent state
err := fn()
if err != nil {
cb.recordFailure()
return err
}
cb.recordSuccess()
return nil
}
// recordFailure records a failure and potentially opens the circuit
func (cb *CircuitBreaker) recordFailure() {
cb.lastFailureTime.Store(time.Now().Unix())
failures := cb.failures.Add(1)
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
if failures >= int64(cb.failureThreshold) {
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
}
}
}
case CircuitBreakerHalfOpen:
// Any failure in half-open state immediately opens the circuit
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
}
}
}
}
// recordSuccess records a success and potentially closes the circuit
func (cb *CircuitBreaker) recordSuccess() {
cb.lastSuccessTime.Store(time.Now().Unix())
state := CircuitBreakerState(cb.state.Load())
switch state {
case CircuitBreakerClosed:
// Reset failure count on success in closed state
cb.failures.Store(0)
case CircuitBreakerHalfOpen:
successes := cb.successes.Add(1)
// If we've had enough successful requests, close the circuit
if successes >= int64(cb.maxRequests) {
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
cb.failures.Store(0)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
}
}
}
}
}
// GetState returns the current state of the circuit breaker
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
return CircuitBreakerState(cb.state.Load())
}
// GetStats returns current statistics for monitoring
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
return CircuitBreakerStats{
Endpoint: cb.endpoint,
State: cb.GetState(),
Failures: cb.failures.Load(),
Successes: cb.successes.Load(),
Requests: cb.requests.Load(),
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
}
}
// CircuitBreakerStats provides statistics about a circuit breaker
type CircuitBreakerStats struct {
Endpoint string
State CircuitBreakerState
Failures int64
Successes int64
Requests int64
LastFailureTime time.Time
LastSuccessTime time.Time
}
// CircuitBreakerEntry wraps a circuit breaker with access tracking
type CircuitBreakerEntry struct {
breaker *CircuitBreaker
lastAccess atomic.Int64 // Unix timestamp
created time.Time
}
// CircuitBreakerManager manages circuit breakers for multiple endpoints
type CircuitBreakerManager struct {
breakers sync.Map // map[string]*CircuitBreakerEntry
config *Config
cleanupStop chan struct{}
cleanupMu sync.Mutex
lastCleanup atomic.Int64 // Unix timestamp
}
// newCircuitBreakerManager creates a new circuit breaker manager
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
cbm := &CircuitBreakerManager{
config: config,
cleanupStop: make(chan struct{}),
}
cbm.lastCleanup.Store(time.Now().Unix())
// Start background cleanup goroutine
go cbm.cleanupLoop()
return cbm
}
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
now := time.Now().Unix()
if entry, ok := cbm.breakers.Load(endpoint); ok {
cbEntry := entry.(*CircuitBreakerEntry)
cbEntry.lastAccess.Store(now)
return cbEntry.breaker
}
// Create new circuit breaker with metadata
newBreaker := newCircuitBreaker(endpoint, cbm.config)
newEntry := &CircuitBreakerEntry{
breaker: newBreaker,
created: time.Now(),
}
newEntry.lastAccess.Store(now)
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
return actual.(*CircuitBreakerEntry).breaker
}
// GetAllStats returns statistics for all circuit breakers
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
var stats []CircuitBreakerStats
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
stats = append(stats, entry.breaker.GetStats())
return true
})
return stats
}
// cleanupLoop runs background cleanup of unused circuit breakers
func (cbm *CircuitBreakerManager) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
defer ticker.Stop()
for {
select {
case <-ticker.C:
cbm.cleanup()
case <-cbm.cleanupStop:
return
}
}
}
// cleanup removes circuit breakers that haven't been accessed recently
func (cbm *CircuitBreakerManager) cleanup() {
// Prevent concurrent cleanups
if !cbm.cleanupMu.TryLock() {
return
}
defer cbm.cleanupMu.Unlock()
now := time.Now()
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
var toDelete []string
count := 0
cbm.breakers.Range(func(key, value interface{}) bool {
endpoint := key.(string)
entry := value.(*CircuitBreakerEntry)
count++
// Remove if not accessed recently
if entry.lastAccess.Load() < cutoff {
toDelete = append(toDelete, endpoint)
}
return true
})
// Delete expired entries
for _, endpoint := range toDelete {
cbm.breakers.Delete(endpoint)
}
// Log cleanup results
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
}
cbm.lastCleanup.Store(now.Unix())
}
// Shutdown stops the cleanup goroutine
func (cbm *CircuitBreakerManager) Shutdown() {
close(cbm.cleanupStop)
}
// Reset resets all circuit breakers (useful for testing)
func (cbm *CircuitBreakerManager) Reset() {
cbm.breakers.Range(func(key, value interface{}) bool {
entry := value.(*CircuitBreakerEntry)
breaker := entry.breaker
breaker.state.Store(int32(CircuitBreakerClosed))
breaker.failures.Store(0)
breaker.successes.Store(0)
breaker.requests.Store(0)
breaker.lastFailureTime.Store(0)
breaker.lastSuccessTime.Store(0)
return true
})
}
+458
View File
@@ -0,0 +1,458 @@
package maintnotifications
import (
"context"
"net"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/util"
)
// Mode represents the maintenance notifications mode
type Mode string
// Constants for maintenance push notifications modes
const (
ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
ModeEnabled Mode = "enabled" // Client forcefully sends command, interrupts connection on error
ModeAuto Mode = "auto" // Client tries to send command, disables feature on error
)
// IsValid returns true if the maintenance notifications mode is valid
func (m Mode) IsValid() bool {
switch m {
case ModeDisabled, ModeEnabled, ModeAuto:
return true
default:
return false
}
}
// String returns the string representation of the mode
func (m Mode) String() string {
return string(m)
}
// EndpointType represents the type of endpoint to request in MOVING notifications
type EndpointType string
// Constants for endpoint types
const (
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
)
// IsValid returns true if the endpoint type is valid
func (e EndpointType) IsValid() bool {
switch e {
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
return true
default:
return false
}
}
// String returns the string representation of the endpoint type
func (e EndpointType) String() string {
return string(e)
}
// Config provides configuration options for maintenance notifications
type Config struct {
// Mode controls how client maintenance notifications are handled.
// Valid values: ModeDisabled, ModeEnabled, ModeAuto
// Default: ModeAuto
Mode Mode
// EndpointType specifies the type of endpoint to request in MOVING notifications.
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
// Default: EndpointTypeAuto
EndpointType EndpointType
// RelaxedTimeout is the concrete timeout value to use during
// MIGRATING/FAILING_OVER states to accommodate increased latency.
// This applies to both read and write timeouts.
// Default: 10 seconds
RelaxedTimeout time.Duration
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
// If handoff takes longer than this, the old connection will be forcibly closed.
// Default: 15 seconds (matches server-side eviction timeout)
HandoffTimeout time.Duration
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
// Workers are created on-demand and automatically cleaned up when idle.
// If zero, defaults to min(10, PoolSize/2) to handle bursts effectively.
// If explicitly set, enforces minimum of PoolSize/2
//
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
MaxWorkers int
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
// If the queue is full, new handoff requests will be rejected.
// Scales with both worker count and pool size for better burst handling.
//
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
HandoffQueueSize int
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
// after a handoff completes. This provides additional resilience during cluster transitions.
// Default: 2 * RelaxedTimeout
PostHandoffRelaxedDuration time.Duration
// Circuit breaker configuration for endpoint failure handling
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
// Default: 5
CircuitBreakerFailureThreshold int
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
// Default: 60 seconds
CircuitBreakerResetTimeout time.Duration
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
// Default: 3
CircuitBreakerMaxRequests int
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
// After this many retries, the connection will be removed from the pool.
// Default: 3
MaxHandoffRetries int
}
func (c *Config) IsEnabled() bool {
return c != nil && c.Mode != ModeDisabled
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
Mode: ModeAuto, // Enable by default for Redis Cloud
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
RelaxedTimeout: 10 * time.Second,
HandoffTimeout: 15 * time.Second,
MaxWorkers: 0, // Auto-calculated based on pool size
HandoffQueueSize: 0, // Auto-calculated based on max workers
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
// Circuit breaker configuration
CircuitBreakerFailureThreshold: 5,
CircuitBreakerResetTimeout: 60 * time.Second,
CircuitBreakerMaxRequests: 3,
// Connection Handoff Configuration
MaxHandoffRetries: 3,
}
}
// Validate checks if the configuration is valid.
func (c *Config) Validate() error {
if c.RelaxedTimeout <= 0 {
return ErrInvalidRelaxedTimeout
}
if c.HandoffTimeout <= 0 {
return ErrInvalidHandoffTimeout
}
// Validate worker configuration
// Allow 0 for auto-calculation, but negative values are invalid
if c.MaxWorkers < 0 {
return ErrInvalidHandoffWorkers
}
// HandoffQueueSize validation - allow 0 for auto-calculation
if c.HandoffQueueSize < 0 {
return ErrInvalidHandoffQueueSize
}
if c.PostHandoffRelaxedDuration < 0 {
return ErrInvalidPostHandoffRelaxedDuration
}
// Circuit breaker validation
if c.CircuitBreakerFailureThreshold < 1 {
return ErrInvalidCircuitBreakerFailureThreshold
}
if c.CircuitBreakerResetTimeout < 0 {
return ErrInvalidCircuitBreakerResetTimeout
}
if c.CircuitBreakerMaxRequests < 1 {
return ErrInvalidCircuitBreakerMaxRequests
}
// Validate Mode (maintenance notifications mode)
if !c.Mode.IsValid() {
return ErrInvalidMaintNotifications
}
// Validate EndpointType
if !c.EndpointType.IsValid() {
return ErrInvalidEndpointType
}
// Validate configuration fields
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
return ErrInvalidHandoffRetries
}
return nil
}
// ApplyDefaults applies default values to any zero-value fields in the configuration.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaults() *Config {
return c.ApplyDefaultsWithPoolSize(0)
}
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
// using the provided pool size to calculate worker defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
}
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
// using the provided pool size and max active connections to calculate worker and queue defaults.
// This ensures that partially configured structs get sensible defaults for missing fields.
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
if c == nil {
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
}
defaults := DefaultConfig()
result := &Config{}
// Apply defaults for enum fields (empty/zero means not set)
result.Mode = defaults.Mode
if c.Mode != "" {
result.Mode = c.Mode
}
result.EndpointType = defaults.EndpointType
if c.EndpointType != "" {
result.EndpointType = c.EndpointType
}
// Apply defaults for duration fields (zero means not set)
result.RelaxedTimeout = defaults.RelaxedTimeout
if c.RelaxedTimeout > 0 {
result.RelaxedTimeout = c.RelaxedTimeout
}
result.HandoffTimeout = defaults.HandoffTimeout
if c.HandoffTimeout > 0 {
result.HandoffTimeout = c.HandoffTimeout
}
// Copy worker configuration
result.MaxWorkers = c.MaxWorkers
// Apply worker defaults based on pool size
result.applyWorkerDefaults(poolSize)
// Apply queue size defaults with new scaling approach
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
workerBasedSize := result.MaxWorkers * 20
poolBasedSize := poolSize
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
if c.HandoffQueueSize > 0 {
// When explicitly set: enforce minimum of 200
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
}
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
var queueCap int
if maxActiveConns > 0 {
queueCap = maxActiveConns + 1
// Ensure queue cap is at least 2 for very small maxActiveConns
if queueCap < 2 {
queueCap = 2
}
} else {
queueCap = poolSize * 5
}
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
// Ensure minimum queue size of 2 (fallback for very small pools)
if result.HandoffQueueSize < 2 {
result.HandoffQueueSize = 2
}
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
if c.PostHandoffRelaxedDuration > 0 {
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
}
// Apply defaults for configuration fields
result.MaxHandoffRetries = defaults.MaxHandoffRetries
if c.MaxHandoffRetries > 0 {
result.MaxHandoffRetries = c.MaxHandoffRetries
}
// Circuit breaker configuration
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
if c.CircuitBreakerFailureThreshold > 0 {
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
}
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
if c.CircuitBreakerResetTimeout > 0 {
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
}
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
if c.CircuitBreakerMaxRequests > 0 {
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
}
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
}
return result
}
// Clone creates a deep copy of the configuration.
func (c *Config) Clone() *Config {
if c == nil {
return DefaultConfig()
}
return &Config{
Mode: c.Mode,
EndpointType: c.EndpointType,
RelaxedTimeout: c.RelaxedTimeout,
HandoffTimeout: c.HandoffTimeout,
MaxWorkers: c.MaxWorkers,
HandoffQueueSize: c.HandoffQueueSize,
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
// Circuit breaker configuration
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
// Configuration fields
MaxHandoffRetries: c.MaxHandoffRetries,
}
}
// applyWorkerDefaults calculates and applies worker defaults based on pool size
func (c *Config) applyWorkerDefaults(poolSize int) {
// Calculate defaults based on pool size
if poolSize <= 0 {
poolSize = 10 * runtime.GOMAXPROCS(0)
}
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
originalMaxWorkers := c.MaxWorkers
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
if originalMaxWorkers != 0 {
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
}
// Ensure minimum of 1 worker (fallback for very small pools)
if c.MaxWorkers < 1 {
c.MaxWorkers = 1
}
}
// DetectEndpointType automatically detects the appropriate endpoint type
// based on the connection address and TLS configuration.
//
// For IP addresses:
// - If TLS is enabled: requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// For hostnames:
// - If TLS is enabled: always requests FQDN for proper certificate validation
// - If TLS is disabled: requests IP for better performance
//
// Internal vs External detection:
// - For IPs: uses private IP range detection
// - For hostnames: uses heuristics based on common internal naming patterns
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
// Extract host from "host:port" format
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr // Assume no port
}
// Check if the host is an IP address or hostname
ip := net.ParseIP(host)
isIPAddress := ip != nil
var endpointType EndpointType
if isIPAddress {
// Address is an IP - determine if it's private or public
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
if tlsEnabled {
// TLS with IP addresses - still prefer FQDN for certificate validation
if isPrivate {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
} else {
// No TLS - can use IP addresses directly
if isPrivate {
endpointType = EndpointTypeInternalIP
} else {
endpointType = EndpointTypeExternalIP
}
}
} else {
// Address is a hostname
isInternalHostname := isInternalHostname(host)
if isInternalHostname {
endpointType = EndpointTypeInternalFQDN
} else {
endpointType = EndpointTypeExternalFQDN
}
}
return endpointType
}
// isInternalHostname determines if a hostname appears to be internal/private.
// This is a heuristic based on common naming patterns.
func isInternalHostname(hostname string) bool {
// Convert to lowercase for comparison
hostname = strings.ToLower(hostname)
// Common internal hostname patterns
internalPatterns := []string{
"localhost",
".local",
".internal",
".corp",
".lan",
".intranet",
".private",
}
// Check for exact match or suffix match
for _, pattern := range internalPatterns {
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
return true
}
}
// Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.)
// If hostname doesn't contain dots, it's likely internal
if !strings.Contains(hostname, ".") {
return true
}
// Default to external for fully qualified domain names
return false
}
+76
View File
@@ -0,0 +1,76 @@
package maintnotifications
import (
"errors"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
)
// Configuration errors
var (
ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError())
ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError())
ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError())
ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError())
ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError())
ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError())
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
// Configuration validation errors
// ErrInvalidHandoffRetries is returned when the number of handoff retries is invalid
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
)
// Integration errors
var (
// ErrInvalidClient is returned when the client does not support push notifications
ErrInvalidClient = errors.New(logs.InvalidClientError())
)
// Handoff errors
var (
// ErrHandoffQueueFull is returned when the handoff queue is full
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
)
// Notification errors
var (
// ErrInvalidNotification is returned when a notification is in an invalid format
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
)
// connection handoff errors
var (
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoff = errors.New(logs.ConnectionMarkedForHandoffErrorMessage)
// ErrConnectionMarkedForHandoffWithState is returned when a connection is marked for handoff
// and should not be used until the handoff is complete
ErrConnectionMarkedForHandoffWithState = errors.New(logs.ConnectionMarkedForHandoffErrorMessage + " with state")
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
ErrConnectionInvalidHandoffState = errors.New(logs.ConnectionInvalidHandoffStateErrorMessage)
)
// shutdown errors
var (
// ErrShutdown is returned when the maintnotifications manager is shutdown
ErrShutdown = errors.New(logs.ShutdownError())
)
// circuit breaker errors
var (
// ErrCircuitBreakerOpen is returned when the circuit breaker is open
ErrCircuitBreakerOpen = errors.New(logs.CircuitBreakerOpenErrorMessage)
)
// circuit breaker configuration errors
var (
// ErrInvalidCircuitBreakerFailureThreshold is returned when the circuit breaker failure threshold is invalid
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
// ErrInvalidCircuitBreakerResetTimeout is returned when the circuit breaker reset timeout is invalid
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
// ErrInvalidCircuitBreakerMaxRequests is returned when the circuit breaker max requests is invalid
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
)
+101
View File
@@ -0,0 +1,101 @@
package maintnotifications
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// contextKey is a custom type for context keys to avoid collisions
type contextKey string
const (
startTimeKey contextKey = "maint_notif_start_time"
)
// MetricsHook collects metrics about notification processing.
type MetricsHook struct {
NotificationCounts map[string]int64
ProcessingTimes map[string]time.Duration
ErrorCounts map[string]int64
HandoffCounts int64 // Total handoffs initiated
HandoffSuccesses int64 // Successful handoffs
HandoffFailures int64 // Failed handoffs
}
// NewMetricsHook creates a new metrics collection hook.
func NewMetricsHook() *MetricsHook {
return &MetricsHook{
NotificationCounts: make(map[string]int64),
ProcessingTimes: make(map[string]time.Duration),
ErrorCounts: make(map[string]int64),
}
}
// PreHook records the start time for processing metrics.
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
mh.NotificationCounts[notificationType]++
// Log connection information if available
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
}
// Store start time in context for duration calculation
startTime := time.Now()
_ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further
return notification, true
}
// PostHook records processing completion and any errors.
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
// Calculate processing duration
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
duration := time.Since(startTime)
mh.ProcessingTimes[notificationType] = duration
}
// Record errors
if result != nil {
mh.ErrorCounts[notificationType]++
// Log error details with connection information
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
}
}
}
// GetMetrics returns a summary of collected metrics.
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
return map[string]interface{}{
"notification_counts": mh.NotificationCounts,
"processing_times": mh.ProcessingTimes,
"error_counts": mh.ErrorCounts,
}
}
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
// Get circuit breaker statistics
stats := poolHook.GetCircuitBreakerStats()
for _, stat := range stats {
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
fmt.Printf(" State: %s\n", stat.State)
fmt.Printf(" Failures: %d\n", stat.Failures)
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
// Alert if circuit breaker is open
if stat.State.String() == "open" {
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
}
}
}
@@ -0,0 +1,512 @@
package maintnotifications
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
// handoffWorkerManager manages background workers and queue for connection handoffs
type handoffWorkerManager struct {
// Event-driven handoff support
handoffQueue chan HandoffRequest // Queue for handoff requests
shutdown chan struct{} // Shutdown signal
shutdownOnce sync.Once // Ensure clean shutdown
workerWg sync.WaitGroup // Track worker goroutines
// On-demand worker management
maxWorkers int
activeWorkers atomic.Int32
workerTimeout time.Duration // How long workers wait for work before exiting
workersScaling atomic.Bool
// Simple state tracking
pending sync.Map // map[uint64]int64 (connID -> seqID)
// Configuration for the maintenance notifications
config *Config
// Pool hook reference for handoff processing
poolHook *PoolHook
// Circuit breaker manager for endpoint failure handling
circuitBreakerManager *CircuitBreakerManager
}
// newHandoffWorkerManager creates a new handoff worker manager
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
return &handoffWorkerManager{
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
shutdown: make(chan struct{}),
maxWorkers: config.MaxWorkers,
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
config: config,
poolHook: poolHook,
circuitBreakerManager: newCircuitBreakerManager(config),
}
}
// getCurrentWorkers returns the current number of active workers (for testing)
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
return int(hwm.activeWorkers.Load())
}
// getPendingMap returns the pending map for testing purposes
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
return &hwm.pending
}
// getMaxWorkers returns the max workers for testing purposes
func (hwm *handoffWorkerManager) getMaxWorkers() int {
return hwm.maxWorkers
}
// getHandoffQueue returns the handoff queue for testing purposes
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
return hwm.handoffQueue
}
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
return hwm.circuitBreakerManager.GetAllStats()
}
// resetCircuitBreakers resets all circuit breakers (useful for testing)
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
hwm.circuitBreakerManager.Reset()
}
// isHandoffPending returns true if the given connection has a pending handoff
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
_, pending := hwm.pending.Load(conn.GetID())
return pending
}
// ensureWorkerAvailable ensures at least one worker is available to process requests
// Creates a new worker if needed and under the max limit
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
select {
case <-hwm.shutdown:
return
default:
if hwm.workersScaling.CompareAndSwap(false, true) {
defer hwm.workersScaling.Store(false)
// Check if we need a new worker
currentWorkers := hwm.activeWorkers.Load()
workersWas := currentWorkers
for currentWorkers < int32(hwm.maxWorkers) {
hwm.workerWg.Add(1)
go hwm.onDemandWorker()
currentWorkers++
}
// workersWas is always <= currentWorkers
// currentWorkers will be maxWorkers, but if we have a worker that was closed
// while we were creating new workers, just add the difference between
// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
hwm.activeWorkers.Add(currentWorkers - workersWas)
}
}
}
// onDemandWorker processes handoff requests and exits when idle
func (hwm *handoffWorkerManager) onDemandWorker() {
defer func() {
// Handle panics to ensure proper cleanup
if r := recover(); r != nil {
internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
}
// Decrement active worker count when exiting
hwm.activeWorkers.Add(-1)
hwm.workerWg.Done()
}()
// Create reusable timer to prevent timer leaks
timer := time.NewTimer(hwm.workerTimeout)
defer timer.Stop()
for {
// Reset timer for next iteration
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timer.Reset(hwm.workerTimeout)
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
}
return
case <-timer.C:
// Worker has been idle for too long, exit to save resources
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
}
return
case request := <-hwm.handoffQueue:
// Check for shutdown before processing
select {
case <-hwm.shutdown:
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
}
// Clean up the request before exiting
hwm.pending.Delete(request.ConnID)
return
default:
// Process the request
hwm.processHandoffRequest(request)
}
}
}
}
// processHandoffRequest processes a single handoff request
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
}
// Create a context with handoff timeout from config
handoffTimeout := 15 * time.Second // Default timeout
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
handoffTimeout = hwm.config.HandoffTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
defer cancel()
// Create a context that also respects the shutdown signal
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
defer shutdownCancel()
// Monitor shutdown signal in a separate goroutine
go func() {
select {
case <-hwm.shutdown:
shutdownCancel()
case <-shutdownCtx.Done():
}
}()
// Perform the handoff with cancellable context
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
minRetryBackoff := 500 * time.Millisecond
if err != nil {
if shouldRetry {
now := time.Now()
deadline, ok := shutdownCtx.Deadline()
thirdOfTimeout := handoffTimeout / 3
if !ok || deadline.Before(now) {
// wait half the timeout before retrying if no deadline or deadline has passed
deadline = now.Add(thirdOfTimeout)
}
afterTime := deadline.Sub(now)
if afterTime < minRetryBackoff {
afterTime = minRetryBackoff
}
if internal.LogLevel.InfoOrAbove() {
// Get current retry count for better logging
currentRetries := request.Conn.HandoffRetries()
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
}
// Schedule retry - keep connection in pending map until retry is queued
time.AfterFunc(afterTime, func() {
if err := hwm.queueHandoff(request.Conn); err != nil {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
}
// Failed to queue retry - remove from pending and close connection
hwm.pending.Delete(request.Conn.GetID())
hwm.closeConnFromRequest(context.Background(), request, err)
} else {
// Successfully queued retry - remove from pending (will be re-added by queueHandoff)
hwm.pending.Delete(request.Conn.GetID())
}
})
return
} else {
// Won't retry - remove from pending and close connection
hwm.pending.Delete(request.Conn.GetID())
go hwm.closeConnFromRequest(ctx, request, err)
}
// Clear handoff state if not returned for retry
seqID := request.Conn.GetMovingSeqID()
connID := request.Conn.GetID()
if hwm.poolHook.operationsManager != nil {
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
}
} else {
// Success - remove from pending map
hwm.pending.Delete(request.Conn.GetID())
}
}
// queueHandoff queues a handoff request for processing
// if err is returned, connection will be removed from pool
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
// Get handoff info atomically to prevent race conditions
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
// on retries the connection will not be marked for handoff, but it will have retries > 0
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
if !shouldHandoff && conn.HandoffRetries() == 0 {
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
}
return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
}
// Create handoff request with atomically retrieved data
request := HandoffRequest{
Conn: conn,
ConnID: conn.GetID(),
Endpoint: endpoint,
SeqID: seqID,
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
}
select {
// priority to shutdown
case <-hwm.shutdown:
return ErrShutdown
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
default:
select {
case <-hwm.shutdown:
return ErrShutdown
case hwm.handoffQueue <- request:
// Store in pending map
hwm.pending.Store(request.ConnID, request.SeqID)
// Ensure we have a worker to process this request
hwm.ensureWorkerAvailable()
return nil
case <-time.After(100 * time.Millisecond): // give workers a chance to process
// Queue is full - log and attempt scaling
queueLen := len(hwm.handoffQueue)
queueCap := cap(hwm.handoffQueue)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
}
}
}
}
// Ensure we have workers available to handle the load
hwm.ensureWorkerAvailable()
return ErrHandoffQueueFull
}
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
hwm.shutdownOnce.Do(func() {
close(hwm.shutdown)
// workers will exit when they finish their current request
// Shutdown circuit breaker manager cleanup goroutine
if hwm.circuitBreakerManager != nil {
hwm.circuitBreakerManager.Shutdown()
}
})
// Wait for workers to complete
done := make(chan struct{})
go func() {
hwm.workerWg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// performConnectionHandoff performs the actual connection handoff
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
// Clear handoff state after successful handoff
connID := conn.GetID()
newEndpoint := conn.GetHandoffEndpoint()
if newEndpoint == "" {
return false, ErrConnectionInvalidHandoffState
}
// Use circuit breaker to protect against failing endpoints
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
// Check if circuit breaker is open before attempting handoff
if circuitBreaker.IsOpen() {
internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
}
// Perform the handoff
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
// Update circuit breaker based on result
if err != nil {
// Only track dial/network errors in circuit breaker, not initialization errors
if shouldRetry {
circuitBreaker.recordFailure()
}
return shouldRetry, err
}
// Success - record in circuit breaker
circuitBreaker.recordSuccess()
return false, nil
}
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
func (hwm *handoffWorkerManager) performHandoffInternal(
ctx context.Context,
conn *pool.Conn,
newEndpoint string,
connID uint64,
) (shouldRetry bool, err error) {
retries := conn.IncrementAndGetHandoffRetries(1)
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
maxRetries := 3 // Default fallback
if hwm.config != nil {
maxRetries = hwm.config.MaxHandoffRetries
}
if retries > maxRetries {
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
}
// won't retry on ErrMaxHandoffRetriesReached
return false, ErrMaxHandoffRetriesReached
}
// Create endpoint-specific dialer
endpointDialer := hwm.createEndpointDialer(newEndpoint)
// Create new connection to the new endpoint
newNetConn, err := endpointDialer(ctx)
if err != nil {
internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
// will retry
// Maybe a network error - retry after a delay
return true, err
}
// Get the old connection
oldConn := conn.GetNetConn()
// Apply relaxed timeout to the new connection for the configured post-handoff duration
// This gives the new connection more time to handle operations during cluster transition
// Setting this here (before initing the connection) ensures that the connection is going
// to use the relaxed timeout for the first operation (auth/ACL select)
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
relaxedTimeout := hwm.config.RelaxedTimeout
// Set relaxed timeout with deadline - no background goroutine needed
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
}
}
// Replace the connection and execute initialization
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
if err != nil {
// won't retry
// Initialization failed - remove the connection
return false, err
}
defer func() {
if oldConn != nil {
oldConn.Close()
}
}()
// Clear handoff state will:
// - set the connection as usable again
// - clear the handoff state (shouldHandoff, endpoint, seqID)
// - reset the handoff retries to 0
// Note: Theoretically there may be a short window where the connection is in the pool
// and IDLE (initConn completed) but still has handoff state set.
conn.ClearHandoffState()
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
// successfully completed the handoff, no retry needed and no error
return false, nil
}
// createEndpointDialer creates a dialer function that connects to a specific endpoint
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
return func(ctx context.Context) (net.Conn, error) {
// Parse endpoint to extract host and port
host, port, err := net.SplitHostPort(endpoint)
if err != nil {
// If no port specified, assume default Redis port
host = endpoint
if port == "" {
port = "6379"
}
}
// Use the base dialer to connect to the new endpoint
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
}
}
// closeConnFromRequest closes the connection and logs the reason
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
pooler := request.Pool
conn := request.Conn
// Clear handoff state before closing
conn.ClearHandoffState()
if pooler != nil {
// Use RemoveWithoutTurn instead of Remove to avoid freeing a turn that we don't have.
// The handoff worker doesn't call Get(), so it doesn't have a turn to free.
// Remove() is meant to be called after Get() and frees a turn.
// RemoveWithoutTurn() removes and closes the connection without affecting the queue.
pooler.RemoveWithoutTurn(ctx, conn, err)
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
}
} else {
err := conn.Close() // Close the connection if no pool provided
if err != nil {
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
}
if internal.LogLevel.WarnOrAbove() {
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
}
}
}
+60
View File
@@ -0,0 +1,60 @@
package maintnotifications
import (
"context"
"slices"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// LoggingHook is an example hook implementation that logs all notifications.
type LoggingHook struct {
LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug
}
// PreHook logs the notification before processing and allows modification.
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
if lh.LogLevel >= 2 { // Info level
// Log the notification type and content
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
seqID := int64(0)
if slices.Contains(maintenanceNotificationTypes, notificationType) {
// seqID is the second element in the notification array
if len(notification) > 1 {
if parsedSeqID, ok := notification[1].(int64); !ok {
seqID = 0
} else {
seqID = parsedSeqID
}
}
}
internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
}
return notification, true // Continue processing with unmodified notification
}
// PostHook logs the result after processing.
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
connID := uint64(0)
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
connID = conn.GetID()
}
if result != nil && lh.LogLevel >= 1 { // Warning level
internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
} else if lh.LogLevel >= 3 { // Debug level
internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
}
}
// NewLoggingHook creates a new logging hook with the specified log level.
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
func NewLoggingHook(logLevel int) *LoggingHook {
return &LoggingHook{LogLevel: logLevel}
}
+320
View File
@@ -0,0 +1,320 @@
package maintnotifications
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/interfaces"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// Push notification type constants for maintenance
const (
NotificationMoving = "MOVING"
NotificationMigrating = "MIGRATING"
NotificationMigrated = "MIGRATED"
NotificationFailingOver = "FAILING_OVER"
NotificationFailedOver = "FAILED_OVER"
)
// maintenanceNotificationTypes contains all notification types that maintenance handles
var maintenanceNotificationTypes = []string{
NotificationMoving,
NotificationMigrating,
NotificationMigrated,
NotificationFailingOver,
NotificationFailedOver,
}
// NotificationHook is called before and after notification processing
// PreHook can modify the notification and return false to skip processing
// PostHook is called after successful processing
type NotificationHook interface {
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
}
// MovingOperationKey provides a unique key for tracking MOVING operations
// that combines sequence ID with connection identifier to handle duplicate
// sequence IDs across multiple connections to the same node.
type MovingOperationKey struct {
SeqID int64 // Sequence ID from MOVING notification
ConnID uint64 // Unique connection identifier
}
// String returns a string representation of the key for debugging
func (k MovingOperationKey) String() string {
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
}
// Manager provides a simplified upgrade functionality with hooks and atomic state.
type Manager struct {
client interfaces.ClientInterface
config *Config
options interfaces.OptionsInterface
pool pool.Pooler
// MOVING operation tracking - using sync.Map for better concurrent performance
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
// Atomic state tracking - no locks needed for state queries
activeOperationCount atomic.Int64 // Number of active operations
closed atomic.Bool // Manager closed state
// Notification hooks for extensibility
hooks []NotificationHook
hooksMu sync.RWMutex // Protects hooks slice
poolHooksRef *PoolHook
}
// MovingOperation tracks an active MOVING operation.
type MovingOperation struct {
SeqID int64
NewEndpoint string
StartTime time.Time
Deadline time.Time
}
// NewManager creates a new simplified manager.
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
if client == nil {
return nil, ErrInvalidClient
}
hm := &Manager{
client: client,
pool: pool,
options: client.GetOptions(),
config: config.Clone(),
hooks: make([]NotificationHook, 0),
}
// Set up push notification handling
if err := hm.setupPushNotifications(); err != nil {
return nil, err
}
return hm, nil
}
// GetPoolHook creates a pool hook with a custom dialer.
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
poolHook := hm.createPoolHook(baseDialer)
hm.pool.AddPoolHook(poolHook)
}
// setupPushNotifications sets up push notification handling by registering with the client's processor.
func (hm *Manager) setupPushNotifications() error {
processor := hm.client.GetPushProcessor()
if processor == nil {
return ErrInvalidClient // Client doesn't support push notifications
}
// Create our notification handler
handler := &NotificationHandler{manager: hm, operationsManager: hm}
// Register handlers for all upgrade notifications with the client's processor
for _, notificationType := range maintenanceNotificationTypes {
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
return errors.New(logs.FailedToRegisterHandler(notificationType, err))
}
}
return nil
}
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Create MOVING operation record
movingOp := &MovingOperation{
SeqID: seqID,
NewEndpoint: newEndpoint,
StartTime: time.Now(),
Deadline: deadline,
}
// Use LoadOrStore for atomic check-and-set operation
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
// Duplicate MOVING notification, ignore
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
}
return nil
}
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
}
// Increment active operation count atomically
hm.activeOperationCount.Add(1)
return nil
}
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
// Create composite key
key := MovingOperationKey{
SeqID: seqID,
ConnID: connID,
}
// Remove from active operations atomically
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
}
// Decrement active operation count only if operation existed
hm.activeOperationCount.Add(-1)
} else {
if internal.LogLevel.DebugOrAbove() { // Debug level
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
}
}
}
// GetActiveMovingOperations returns active operations with composite keys.
// WARNING: This method creates a new map and copies all operations on every call.
// Use sparingly, especially in hot paths or high-frequency logging.
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
result := make(map[MovingOperationKey]*MovingOperation)
// Iterate over sync.Map to build result
hm.activeMovingOps.Range(func(key, value interface{}) bool {
k := key.(MovingOperationKey)
op := value.(*MovingOperation)
// Create a copy to avoid sharing references
result[k] = &MovingOperation{
SeqID: op.SeqID,
NewEndpoint: op.NewEndpoint,
StartTime: op.StartTime,
Deadline: op.Deadline,
}
return true // Continue iteration
})
return result
}
// IsHandoffInProgress returns true if any handoff is in progress.
// Uses atomic counter for lock-free operation.
func (hm *Manager) IsHandoffInProgress() bool {
return hm.activeOperationCount.Load() > 0
}
// GetActiveOperationCount returns the number of active operations.
// Uses atomic counter for lock-free operation.
func (hm *Manager) GetActiveOperationCount() int64 {
return hm.activeOperationCount.Load()
}
// Close closes the manager.
func (hm *Manager) Close() error {
// Use atomic operation for thread-safe close check
if !hm.closed.CompareAndSwap(false, true) {
return nil // Already closed
}
// Shutdown the pool hook if it exists
if hm.poolHooksRef != nil {
// Use a timeout to prevent hanging indefinitely
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err := hm.poolHooksRef.Shutdown(shutdownCtx)
if err != nil {
// was not able to close pool hook, keep closed state false
hm.closed.Store(false)
return err
}
// Remove the pool hook from the pool
if hm.pool != nil {
hm.pool.RemovePoolHook(hm.poolHooksRef)
}
}
// Clear all active operations
hm.activeMovingOps.Range(func(key, value interface{}) bool {
hm.activeMovingOps.Delete(key)
return true
})
// Reset counter
hm.activeOperationCount.Store(0)
return nil
}
// GetState returns current state using atomic counter for lock-free operation.
func (hm *Manager) GetState() State {
if hm.activeOperationCount.Load() > 0 {
return StateMoving
}
return StateIdle
}
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
currentNotification := notification
for _, hook := range hm.hooks {
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
if !shouldContinue {
return modifiedNotification, false
}
currentNotification = modifiedNotification
}
return currentNotification, true
}
// processPostHooks calls all post-hooks with the processing result.
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
hm.hooksMu.RLock()
defer hm.hooksMu.RUnlock()
for _, hook := range hm.hooks {
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
}
}
// createPoolHook creates a pool hook with this manager already set.
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
if hm.poolHooksRef != nil {
return hm.poolHooksRef
}
// Get pool size from client options for better worker defaults
poolSize := 0
if hm.options != nil {
poolSize = hm.options.GetPoolSize()
}
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
hm.poolHooksRef.SetPool(hm.pool)
return hm.poolHooksRef
}
func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
hm.hooksMu.Lock()
defer hm.hooksMu.Unlock()
hm.hooks = append(hm.hooks, notificationHook)
}
+182
View File
@@ -0,0 +1,182 @@
package maintnotifications
import (
"context"
"net"
"sync"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
)
// OperationsManagerInterface defines the interface for completing handoff operations
type OperationsManagerInterface interface {
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
UntrackOperationWithConnID(seqID int64, connID uint64)
}
// HandoffRequest represents a request to handoff a connection to a new endpoint
type HandoffRequest struct {
Conn *pool.Conn
ConnID uint64 // Unique connection identifier
Endpoint string
SeqID int64
Pool pool.Pooler // Pool to remove connection from on failure
}
// PoolHook implements pool.PoolHook for Redis-specific connection handling
// with maintenance notifications support.
type PoolHook struct {
// Base dialer for creating connections to new endpoints during handoffs
// args are network and address
baseDialer func(context.Context, string, string) (net.Conn, error)
// Network type (e.g., "tcp", "unix")
network string
// Worker manager for background handoff processing
workerManager *handoffWorkerManager
// Configuration for the maintenance notifications
config *Config
// Operations manager interface for operation completion tracking
operationsManager OperationsManagerInterface
// Pool interface for removing connections on handoff failure
pool pool.Pooler
}
// NewPoolHook creates a new pool hook
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
}
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
// Apply defaults if config is nil or has zero values
if config == nil {
config = config.ApplyDefaultsWithPoolSize(poolSize)
}
ph := &PoolHook{
// baseDialer is used to create connections to new endpoints during handoffs
baseDialer: baseDialer,
network: network,
config: config,
operationsManager: operationsManager,
}
// Create worker manager
ph.workerManager = newHandoffWorkerManager(config, ph)
return ph
}
// SetPool sets the pool interface for removing connections on handoff failure
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
ph.pool = pooler
}
// GetCurrentWorkers returns the current number of active workers (for testing)
func (ph *PoolHook) GetCurrentWorkers() int {
return ph.workerManager.getCurrentWorkers()
}
// IsHandoffPending returns true if the given connection has a pending handoff
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
return ph.workerManager.isHandoffPending(conn)
}
// GetPendingMap returns the pending map for testing purposes
func (ph *PoolHook) GetPendingMap() *sync.Map {
return ph.workerManager.getPendingMap()
}
// GetMaxWorkers returns the max workers for testing purposes
func (ph *PoolHook) GetMaxWorkers() int {
return ph.workerManager.getMaxWorkers()
}
// GetHandoffQueue returns the handoff queue for testing purposes
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
return ph.workerManager.getHandoffQueue()
}
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
return ph.workerManager.getCircuitBreakerStats()
}
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
func (ph *PoolHook) ResetCircuitBreakers() {
ph.workerManager.resetCircuitBreakers()
}
// OnGet is called when a connection is retrieved from the pool
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
// Check if connection is marked for handoff
// This prevents using connections that have received MOVING notifications
if conn.ShouldHandoff() {
return false, ErrConnectionMarkedForHandoffWithState
}
// Check if connection is usable (not in UNUSABLE or CLOSED state)
// This ensures we don't return connections that are currently being handed off or re-authenticated.
if !conn.IsUsable() {
return false, ErrConnectionMarkedForHandoff
}
return true, nil
}
// OnPut is called when a connection is returned to the pool
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
// first check if we should handoff for faster rejection
if !conn.ShouldHandoff() {
// Default behavior (no handoff): pool the connection
return true, false, nil
}
// check pending handoff to not queue the same connection twice
if ph.workerManager.isHandoffPending(conn) {
// Default behavior (pending handoff): pool the connection
return true, false, nil
}
if err := ph.workerManager.queueHandoff(conn); err != nil {
// Failed to queue handoff, remove the connection
internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
// Don't pool, remove connection, no error to caller
return false, true, nil
}
// Check if handoff was already processed by a worker before we can mark it as queued
if !conn.ShouldHandoff() {
// Handoff was already processed - this is normal and the connection should be pooled
return true, false, nil
}
if err := conn.MarkQueuedForHandoff(); err != nil {
// If marking fails, check if handoff was processed in the meantime
if !conn.ShouldHandoff() {
// Handoff was processed - this is normal, pool the connection
return true, false, nil
}
// Other error - remove the connection
return false, true, nil
}
internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
return true, false, nil
}
func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) {
// Not used
}
// Shutdown gracefully shuts down the processor, waiting for workers to complete
func (ph *PoolHook) Shutdown(ctx context.Context) error {
return ph.workerManager.shutdownWorkers(ctx)
}
@@ -0,0 +1,282 @@
package maintnotifications
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/push"
)
// NotificationHandler handles push notifications for the simplified manager.
type NotificationHandler struct {
manager *Manager
operationsManager OperationsManagerInterface
}
// HandlePushNotification processes push notifications with hook support.
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) == 0 {
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
return ErrInvalidNotification
}
notificationType, ok := notification[0].(string)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
return ErrInvalidNotification
}
// Process pre-hooks - they can modify the notification or skip processing
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
if !shouldContinue {
return nil // Hooks decided to skip processing
}
var err error
switch notificationType {
case NotificationMoving:
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
case NotificationMigrating:
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
case NotificationMigrated:
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
case NotificationFailingOver:
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
case NotificationFailedOver:
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
default:
// Ignore other notification types (e.g., pub/sub messages)
err = nil
}
// Process post-hooks with the result
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
return err
}
// handleMoving processes MOVING notifications.
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) < 3 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
return ErrInvalidNotification
}
seqID, ok := notification[1].(int64)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
return ErrInvalidNotification
}
// Extract timeS
timeS, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
return ErrInvalidNotification
}
newEndpoint := ""
if len(notification) > 3 {
// Extract new endpoint
newEndpoint, ok = notification[3].(string)
if !ok {
stringified := fmt.Sprintf("%v", notification[3])
// this could be <nil> which is valid
if notification[3] == nil || stringified == internal.RedisNull {
newEndpoint = ""
} else {
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
return ErrInvalidNotification
}
}
}
// Get the connection that received this notification
conn := handlerCtx.Conn
if conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
return ErrInvalidNotification
}
// Type assert to get the underlying pool connection
var poolConn *pool.Conn
if pc, ok := conn.(*pool.Conn); ok {
poolConn = pc
} else {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
return ErrInvalidNotification
}
// If the connection is closed or not pooled, we can ignore the notification
// this connection won't be remembered by the pool and will be garbage collected
// Keep pubsub connections around since they are not pooled but are long-lived
// and should be allowed to handoff (the pubsub instance will reconnect and change
// the underlying *pool.Conn)
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
return nil
}
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
if newEndpoint == "" || newEndpoint == internal.RedisNull {
if internal.LogLevel.DebugOrAbove() {
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
}
// same as current endpoint
newEndpoint = snh.manager.options.GetAddr()
// delay the handoff for timeS/2 seconds to the same endpoint
// do this in a goroutine to avoid blocking the notification handler
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
// and there should be no possibility of a race condition or double handoff.
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
if poolConn == nil || poolConn.IsClosed() {
return
}
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
// Log error but don't fail the goroutine - use background context since original may be cancelled
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
}
})
return nil
}
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
}
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
// Connection is already marked for handoff, which is acceptable
// This can happen if multiple MOVING notifications are received for the same connection
return nil
}
// Optionally track in m
if snh.operationsManager != nil {
connID := conn.GetID()
// Track the operation (ignore errors since this is optional)
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
} else {
return errors.New(logs.ManagerNotInitialized())
}
return nil
}
// handleMigrating processes MIGRATING notifications.
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATING notifications indicate that a connection is about to be migrated
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if internal.LogLevel.InfoOrAbove() {
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleMigrated processes MIGRATED notifications.
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// MIGRATED notifications indicate that a connection migration has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
}
// handleFailingOver processes FAILING_OVER notifications.
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILING_OVER notifications indicate that a connection is about to failover
// Apply relaxed timeouts to the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Apply relaxed timeout to this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
}
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
return nil
}
// handleFailedOver processes FAILED_OVER notifications.
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
// FAILED_OVER notifications indicate that a connection failover has completed
// Restore normal timeouts for the specific connection that received this notification
if len(notification) < 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
return ErrInvalidNotification
}
if handlerCtx.Conn == nil {
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
return ErrInvalidNotification
}
conn, ok := handlerCtx.Conn.(*pool.Conn)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
return ErrInvalidNotification
}
// Clear relaxed timeout for this specific connection
if internal.LogLevel.InfoOrAbove() {
connID := conn.GetID()
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
}
conn.ClearRelaxedTimeout()
return nil
}
+24
View File
@@ -0,0 +1,24 @@
package maintnotifications
// State represents the current state of a maintenance operation
type State int
const (
// StateIdle indicates no upgrade is in progress
StateIdle State = iota
// StateHandoff indicates a connection handoff is in progress
StateMoving
)
// String returns a string representation of the state.
func (s State) String() string {
switch s {
case StateIdle:
return "idle"
case StateMoving:
return "moving"
default:
return "unknown"
}
}
+149 -14
View File
@@ -16,6 +16,9 @@ import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
// Limiter is the interface of a rate limiter or a circuit breaker.
@@ -31,7 +34,6 @@ type Limiter interface {
// Options keeps the settings to set up redis connection.
type Options struct {
// Network type, either tcp or unix.
//
// default: is tcp.
@@ -109,6 +111,16 @@ type Options struct {
// default: 5 seconds
DialTimeout time.Duration
// DialerRetries is the maximum number of retry attempts when dialing fails.
//
// default: 5
DialerRetries int
// DialerRetryTimeout is the backoff duration between retry attempts.
//
// default: 100 milliseconds
DialerRetryTimeout time.Duration
// ReadTimeout for socket reads. If reached, commands will fail
// with a timeout instead of blocking. Supported values:
//
@@ -152,6 +164,7 @@ type Options struct {
//
// Note that FIFO has slightly higher overhead compared to LIFO,
// but it helps closing idle connections faster reducing the pool size.
// default: false
PoolFIFO bool
// PoolSize is the base number of socket connections.
@@ -162,6 +175,10 @@ type Options struct {
// default: 10 * runtime.GOMAXPROCS(0)
PoolSize int
// MaxConcurrentDials is the maximum number of concurrent connection creation goroutines.
// If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize.
MaxConcurrentDials int
// PoolTimeout is the amount of time client waits for connection if all connections
// are busy before returning an error.
//
@@ -232,10 +249,24 @@ type Options struct {
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
UnstableResp3 bool
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
// and are not available for RESP2 connections. No configuration option is needed.
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// MaintNotificationsConfig provides custom configuration for maintnotifications.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
MaintNotificationsConfig *maintnotifications.Config
}
func (opt *Options) init() {
@@ -255,12 +286,23 @@ func (opt *Options) init() {
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
}
if opt.DialerRetries == 0 {
opt.DialerRetries = 5
}
if opt.DialerRetryTimeout == 0 {
opt.DialerRetryTimeout = 100 * time.Millisecond
}
if opt.Dialer == nil {
opt.Dialer = NewDialer(opt)
}
if opt.PoolSize == 0 {
opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
}
if opt.MaxConcurrentDials <= 0 {
opt.MaxConcurrentDials = opt.PoolSize
} else if opt.MaxConcurrentDials > opt.PoolSize {
opt.MaxConcurrentDials = opt.PoolSize
}
if opt.ReadBufferSize == 0 {
opt.ReadBufferSize = proto.DefaultBufferSize
}
@@ -312,13 +354,40 @@ func (opt *Options) init() {
case 0:
opt.MaxRetryBackoff = 512 * time.Millisecond
}
if opt.FailingTimeoutSeconds == 0 {
opt.FailingTimeoutSeconds = 15
}
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
// auto-detect endpoint type if not specified
endpointType := opt.MaintNotificationsConfig.EndpointType
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
// Auto-detect endpoint type if not specified
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
}
opt.MaintNotificationsConfig.EndpointType = endpointType
}
func (opt *Options) clone() *Options {
clone := *opt
// Deep clone MaintNotificationsConfig to avoid sharing between clients
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
clone.MaintNotificationsConfig = &configClone
}
return &clone
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
return NewDialer(opt)
}
// NewDialer returns a function that will be used as the default dialer
// when none is specified in Options.Dialer.
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
@@ -565,6 +634,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) {
o.MinIdleConns = q.int("min_idle_conns")
o.MaxIdleConns = q.int("max_idle_conns")
o.MaxActiveConns = q.int("max_active_conns")
o.MaxConcurrentDials = q.int("max_concurrent_dials")
if q.has("conn_max_idle_time") {
o.ConnMaxIdleTime = q.duration("conn_max_idle_time")
} else {
@@ -604,21 +674,86 @@ func getUserPassword(u *url.URL) (string, string) {
func newConnPool(
opt *Options,
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) *pool.ConnPool {
) (*pool.ConnPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewConnPool(&pool.Options{
Dialer: func(ctx context.Context) (net.Conn, error) {
return dialer(ctx, opt.Network, opt.Addr)
},
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
})
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
MaxConcurrentDials: opt.MaxConcurrentDials,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PushNotificationsEnabled: opt.Protocol == 3,
}), nil
}
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
) (*pool.PubSubPool, error) {
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
if err != nil {
return nil, err
}
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
if err != nil {
return nil, err
}
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
if err != nil {
return nil, err
}
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
if err != nil {
return nil, err
}
return pool.NewPubSubPool(&pool.Options{
PoolFIFO: opt.PoolFIFO,
PoolSize: poolSize,
MaxConcurrentDials: opt.MaxConcurrentDials,
PoolTimeout: opt.PoolTimeout,
DialTimeout: opt.DialTimeout,
DialerRetries: opt.DialerRetries,
DialerRetryTimeout: opt.DialerRetryTimeout,
MinIdleConns: minIdleConns,
MaxIdleConns: maxIdleConns,
MaxActiveConns: maxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: 32 * 1024,
WriteBufferSize: 32 * 1024,
PushNotificationsEnabled: opt.Protocol == 3,
}, dialer), nil
}
+67 -12
View File
@@ -20,6 +20,8 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
const (
@@ -38,6 +40,7 @@ type ClusterOptions struct {
ClientName string
// NewClient creates a cluster node client with provided name and options.
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
NewClient func(opt *Options) *Client
// The maximum number of retries before giving up. Command is retried
@@ -74,6 +77,10 @@ type ClusterOptions struct {
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
StreamingCredentialsProvider auth.StreamingCredentialsProvider
// MaxRetries is the maximum number of retries before giving up.
// For ClusterClient, retries are disabled by default (set to -1),
// because the cluster client handles all kinds of retries internally.
// This is intentional and differs from the standalone Options default.
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
@@ -125,10 +132,22 @@ type ClusterOptions struct {
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
UnstableResp3 bool
// PushNotificationProcessor is the processor for handling push notifications.
// If nil, a default processor will be created for RESP3 connections.
PushNotificationProcessor push.NotificationProcessor
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
// When a node is marked as failing, it will be avoided for this duration.
// Default is 15 seconds.
FailingTimeoutSeconds int
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// cluster upgrade notifications gracefully and manage connection/pool state
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
// The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications.
MaintNotificationsConfig *maintnotifications.Config
}
func (opt *ClusterOptions) init() {
@@ -319,6 +338,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
}
func (opt *ClusterOptions) clientOptions() *Options {
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
var maintNotificationsConfig *maintnotifications.Config
if opt.MaintNotificationsConfig != nil {
configClone := *opt.MaintNotificationsConfig
maintNotificationsConfig = &configClone
}
return &Options{
ClientName: opt.ClientName,
Dialer: opt.Dialer,
@@ -360,8 +386,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
// much use for ClusterSlots config). This means we cannot execute the
// READONLY command against that node -- setting readOnly to false in such
// situations in the options below will prevent that from happening.
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
UnstableResp3: opt.UnstableResp3,
MaintNotificationsConfig: maintNotificationsConfig,
PushNotificationProcessor: opt.PushNotificationProcessor,
}
}
@@ -1664,7 +1692,7 @@ func (c *ClusterClient) processTxPipelineNode(
}
func (c *ClusterClient) processTxPipelineNodeConn(
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
@@ -1682,7 +1710,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
trimmedCmds := cmds[1 : len(cmds)-1]
if err := c.txPipelineReadQueued(
ctx, rd, statusCmd, trimmedCmds, failedCmds,
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
); err != nil {
setCmdsErr(cmds, err)
@@ -1694,23 +1722,37 @@ func (c *ClusterClient) processTxPipelineNodeConn(
return err
}
return pipelineReadCmds(rd, trimmedCmds)
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
})
}
func (c *ClusterClient) txPipelineReadQueued(
ctx context.Context,
node *clusterNode,
cn *pool.Conn,
rd *proto.Reader,
statusCmd *StatusCmd,
cmds []Cmder,
failedCmds *cmdsMap,
) error {
// Parse queued replies.
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
if err := statusCmd.readReply(rd); err != nil {
return err
}
for _, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
err := statusCmd.readReply(rd)
if err != nil {
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
@@ -1724,6 +1766,12 @@ func (c *ClusterClient) txPipelineReadQueued(
}
}
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
@@ -1829,12 +1877,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
return err
}
// maintenance notifications won't work here for now
func (c *ClusterClient) pubSub() *PubSub {
var node *clusterNode
pubsub := &PubSub{
opt: c.opt.clientOptions(),
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
if node != nil {
panic("node != nil")
}
@@ -1868,18 +1916,25 @@ func (c *ClusterClient) pubSub() *PubSub {
return nil, err
}
}
cn, err := node.Client.newConn(context.TODO())
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
if err != nil {
node = nil
return nil, err
}
// will return nil if already initialized
err = node.Client.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
node = nil
return nil, err
}
node.Client.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: func(cn *pool.Conn) error {
err := node.Client.connPool.CloseConn(cn)
// Untrack connection from PubSubPool
node.Client.pubSubPool.UntrackConn(cn)
err := cn.Close()
node = nil
return err
},
+74 -7
View File
@@ -10,6 +10,7 @@ import (
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/push"
)
// PubSub implements Pub/Sub commands as described in
@@ -21,7 +22,7 @@ import (
type PubSub struct {
opt *Options
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
closeConn func(*pool.Conn) error
mu sync.Mutex
@@ -38,6 +39,12 @@ type PubSub struct {
chOnce sync.Once
msgCh *channel
allCh *channel
// Push notification processor for handling generic push notifications
pushProcessor push.NotificationProcessor
// Cleanup callback for maintenanceNotifications upgrade tracking
onClose func()
}
func (c *PubSub) init() {
@@ -69,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
return c.cn, nil
}
if c.opt.Addr == "" {
// TODO(maintenanceNotifications):
// this is probably cluster client
// c.newConn will ignore the addr argument
// will be changed when we have maintenanceNotifications upgrades for cluster clients
c.opt.Addr = internal.RedisNull
}
channels := mapKeys(c.channels)
channels = append(channels, newChannels...)
cn, err := c.newConn(ctx, channels)
cn, err := c.newConn(ctx, c.opt.Addr, channels)
if err != nil {
return nil, err
}
@@ -153,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
if c.cn != cn {
return
}
if !cn.IsUsable() || cn.ShouldHandoff() {
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
}
if isBadConn(err, allowTimeout, c.opt.Addr) {
c.reconnect(ctx, err)
}
}
func (c *PubSub) reconnect(ctx context.Context, reason error) {
if c.cn != nil && c.cn.ShouldHandoff() {
newEndpoint := c.cn.GetHandoffEndpoint()
// If new endpoint is NULL, use the original address
if newEndpoint == internal.RedisNull {
newEndpoint = c.opt.Addr
}
if newEndpoint != "" {
// Update the address in the options
oldAddr := c.cn.RemoteAddr().String()
c.opt.Addr = newEndpoint
internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr)
}
}
_ = c.closeTheCn(reason)
_, _ = c.conn(ctx, nil)
}
@@ -167,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error {
if c.cn == nil {
return nil
}
if !c.closed {
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
}
err := c.closeConn(c.cn)
c.cn = nil
return err
@@ -185,6 +216,11 @@ func (c *PubSub) Close() error {
c.closed = true
close(c.exit)
// Call cleanup callback if set
if c.onClose != nil {
c.onClose()
}
return c.closeTheCn(pool.ErrClosed)
}
@@ -429,16 +465,20 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
}
// Don't hold the lock to allow subscriptions and pings.
cn, err := c.connWithLock(ctx)
if err != nil {
return nil, err
}
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
// Log the error but don't fail the command execution
// Push notification processing errors shouldn't break normal Redis operations
internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err)
}
return c.cmd.readReply(rd)
})
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
if err != nil {
@@ -451,6 +491,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
// Receive returns a message as a Subscription, Message, Pong, or an error.
// See PubSub example for details. This is a low-level API and in most cases
// Channel should be used instead.
// This method blocks until a message is received or an error occurs.
// It may return early with an error if the context is canceled, the connection fails,
// or other internal errors occur.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
return c.ReceiveTimeout(ctx, 0)
}
@@ -532,6 +578,27 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
return c.allCh.allCh
}
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
}
func (c *PubSub) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
// PubSub doesn't have a client or connection pool, so we pass nil for those
// PubSub connections are blocking
return push.NotificationHandlerContext{
PubSub: c,
Conn: cn,
IsBlocking: true,
}
}
type ChannelOption func(c *channel)
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
+176
View File
@@ -0,0 +1,176 @@
package push
import (
"errors"
"fmt"
)
// Push notification error definitions
// This file contains all error types and messages used by the push notification system
// Error reason constants
const (
// HandlerReasons
ReasonHandlerNil = "handler cannot be nil"
ReasonHandlerExists = "cannot overwrite existing handler"
ReasonHandlerProtected = "handler is protected"
// ProcessorReasons
ReasonPushNotificationsDisabled = "push notifications are disabled"
)
// ProcessorType represents the type of processor involved in the error
// defined as a custom type for better readability and easier maintenance
type ProcessorType string
const (
// ProcessorTypes
ProcessorTypeProcessor = ProcessorType("processor")
ProcessorTypeVoidProcessor = ProcessorType("void_processor")
ProcessorTypeCustom = ProcessorType("custom")
)
// ProcessorOperation represents the operation being performed by the processor
// defined as a custom type for better readability and easier maintenance
type ProcessorOperation string
const (
// ProcessorOperations
ProcessorOperationProcess = ProcessorOperation("process")
ProcessorOperationRegister = ProcessorOperation("register")
ProcessorOperationUnregister = ProcessorOperation("unregister")
ProcessorOperationUnknown = ProcessorOperation("unknown")
)
// Common error variables for reuse
var (
// ErrHandlerNil is returned when attempting to register a nil handler
ErrHandlerNil = errors.New(ReasonHandlerNil)
)
// Registry errors
// ErrHandlerExists creates an error for when attempting to overwrite an existing handler
func ErrHandlerExists(pushNotificationName string) error {
return NewHandlerError(ProcessorOperationRegister, pushNotificationName, ReasonHandlerExists, nil)
}
// ErrProtectedHandler creates an error for when attempting to unregister a protected handler
func ErrProtectedHandler(pushNotificationName string) error {
return NewHandlerError(ProcessorOperationUnregister, pushNotificationName, ReasonHandlerProtected, nil)
}
// VoidProcessor errors
// ErrVoidProcessorRegister creates an error for when attempting to register a handler on void processor
func ErrVoidProcessorRegister(pushNotificationName string) error {
return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationRegister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}
// ErrVoidProcessorUnregister creates an error for when attempting to unregister a handler on void processor
func ErrVoidProcessorUnregister(pushNotificationName string) error {
return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationUnregister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}
// Error type definitions for advanced error handling
// HandlerError represents errors related to handler operations
type HandlerError struct {
Operation ProcessorOperation
PushNotificationName string
Reason string
Err error
}
func (e *HandlerError) Error() string {
if e.Err != nil {
return fmt.Sprintf("handler %s failed for '%s': %s (%v)", e.Operation, e.PushNotificationName, e.Reason, e.Err)
}
return fmt.Sprintf("handler %s failed for '%s': %s", e.Operation, e.PushNotificationName, e.Reason)
}
func (e *HandlerError) Unwrap() error {
return e.Err
}
// NewHandlerError creates a new HandlerError
func NewHandlerError(operation ProcessorOperation, pushNotificationName, reason string, err error) *HandlerError {
return &HandlerError{
Operation: operation,
PushNotificationName: pushNotificationName,
Reason: reason,
Err: err,
}
}
// ProcessorError represents errors related to processor operations
type ProcessorError struct {
ProcessorType ProcessorType // "processor", "void_processor"
Operation ProcessorOperation // "process", "register", "unregister"
PushNotificationName string // Name of the push notification involved
Reason string
Err error
}
func (e *ProcessorError) Error() string {
notifInfo := ""
if e.PushNotificationName != "" {
notifInfo = fmt.Sprintf(" for '%s'", e.PushNotificationName)
}
if e.Err != nil {
return fmt.Sprintf("%s %s failed%s: %s (%v)", e.ProcessorType, e.Operation, notifInfo, e.Reason, e.Err)
}
return fmt.Sprintf("%s %s failed%s: %s", e.ProcessorType, e.Operation, notifInfo, e.Reason)
}
func (e *ProcessorError) Unwrap() error {
return e.Err
}
// NewProcessorError creates a new ProcessorError
func NewProcessorError(processorType ProcessorType, operation ProcessorOperation, pushNotificationName, reason string, err error) *ProcessorError {
return &ProcessorError{
ProcessorType: processorType,
Operation: operation,
PushNotificationName: pushNotificationName,
Reason: reason,
Err: err,
}
}
// Helper functions for common error scenarios
// IsHandlerNilError checks if an error is due to a nil handler
func IsHandlerNilError(err error) bool {
return errors.Is(err, ErrHandlerNil)
}
// IsHandlerExistsError checks if an error is due to attempting to overwrite an existing handler.
// This function works correctly even when the error is wrapped.
func IsHandlerExistsError(err error) bool {
var handlerErr *HandlerError
if errors.As(err, &handlerErr) {
return handlerErr.Operation == ProcessorOperationRegister && handlerErr.Reason == ReasonHandlerExists
}
return false
}
// IsProtectedHandlerError checks if an error is due to attempting to unregister a protected handler.
// This function works correctly even when the error is wrapped.
func IsProtectedHandlerError(err error) bool {
var handlerErr *HandlerError
if errors.As(err, &handlerErr) {
return handlerErr.Operation == ProcessorOperationUnregister && handlerErr.Reason == ReasonHandlerProtected
}
return false
}
// IsVoidProcessorError checks if an error is due to void processor operations.
// This function works correctly even when the error is wrapped.
func IsVoidProcessorError(err error) bool {
var procErr *ProcessorError
if errors.As(err, &procErr) {
return procErr.ProcessorType == ProcessorTypeVoidProcessor && procErr.Reason == ReasonPushNotificationsDisabled
}
return false
}
+14
View File
@@ -0,0 +1,14 @@
package push
import (
"context"
)
// NotificationHandler defines the interface for push notification handlers.
type NotificationHandler interface {
// HandlePushNotification processes a push notification with context information.
// The handlerCtx provides information about the client, connection pool, and connection
// on which the notification was received, allowing handlers to make informed decisions.
// Returns an error if the notification could not be handled.
HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error
}
+44
View File
@@ -0,0 +1,44 @@
package push
// No imports needed for this file
// NotificationHandlerContext provides context information about where a push notification was received.
// This struct allows handlers to make informed decisions based on the source of the notification
// with strongly typed access to different client types using concrete types.
type NotificationHandlerContext struct {
// Client is the Redis client instance that received the notification.
// It is interface to both allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *redis.baseClient
// - *redis.Client
// - *redis.ClusterClient
// - *redis.Conn
Client interface{}
// ConnPool is the connection pool from which the connection was obtained.
// It is interface to both allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.ConnPool
// - *pool.SingleConnPool
// - *pool.StickyConnPool
ConnPool interface{}
// PubSub is the PubSub instance that received the notification.
// It is interface to both allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *redis.PubSub
PubSub interface{}
// Conn is the specific connection on which the notification was received.
// It is interface to both allow for future expansion and to avoid
// circular dependencies. The developer is responsible for type assertion.
// It can be one of the following types:
// - *pool.Conn
Conn interface{}
// IsBlocking indicates if the notification was received on a blocking connection.
IsBlocking bool
}
+203
View File
@@ -0,0 +1,203 @@
package push
import (
"context"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/proto"
)
// NotificationProcessor defines the interface for push notification processors.
type NotificationProcessor interface {
// GetHandler returns the handler for a specific push notification name.
GetHandler(pushNotificationName string) NotificationHandler
// ProcessPendingNotifications checks for and processes any pending push notifications.
// To be used when it is known that there are notifications on the socket.
// It will try to read from the socket and if it is empty - it may block.
ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error
// RegisterHandler registers a handler for a specific push notification name.
RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error
// UnregisterHandler removes a handler for a specific push notification name.
UnregisterHandler(pushNotificationName string) error
}
// Processor handles push notifications with a registry of handlers
type Processor struct {
registry *Registry
}
// NewProcessor creates a new push notification processor
func NewProcessor() *Processor {
return &Processor{
registry: NewRegistry(),
}
}
// GetHandler returns the handler for a specific push notification name
func (p *Processor) GetHandler(pushNotificationName string) NotificationHandler {
return p.registry.GetHandler(pushNotificationName)
}
// RegisterHandler registers a handler for a specific push notification name
func (p *Processor) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
return p.registry.RegisterHandler(pushNotificationName, handler, protected)
}
// UnregisterHandler removes a handler for a specific push notification name
func (p *Processor) UnregisterHandler(pushNotificationName string) error {
return p.registry.UnregisterHandler(pushNotificationName)
}
// ProcessPendingNotifications checks for and processes any pending push notifications
// This method should be called by the client in WithReader before reading the reply
// It will try to read from the socket and if it is empty - it may block.
func (p *Processor) ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
if rd == nil {
return nil
}
for {
// Check if there's data available to read
replyType, err := rd.PeekReplyType()
if err != nil {
// No more data available or error reading
// if timeout, it will be handled by the caller
break
}
// Only process push notifications (arrays starting with >)
if replyType != proto.RespPush {
break
}
// see if we should skip this notification
notificationName, err := rd.PeekPushNotificationName()
if err != nil {
break
}
if willHandleNotificationInClient(notificationName) {
break
}
// Read the push notification
reply, err := rd.ReadReply()
if err != nil {
internal.Logger.Printf(ctx, "push: error reading push notification: %v", err)
break
}
// Convert to slice of interfaces
notification, ok := reply.([]interface{})
if !ok {
break
}
// Handle the notification directly
if len(notification) > 0 {
// Extract the notification type (first element)
if notificationType, ok := notification[0].(string); ok {
// Get the handler for this notification type
if handler := p.registry.GetHandler(notificationType); handler != nil {
// Handle the notification
err := handler.HandlePushNotification(ctx, handlerCtx, notification)
if err != nil {
internal.Logger.Printf(ctx, "push: error handling push notification: %v", err)
}
}
}
}
}
return nil
}
// VoidProcessor discards all push notifications without processing them
type VoidProcessor struct{}
// NewVoidProcessor creates a new void push notification processor
func NewVoidProcessor() *VoidProcessor {
return &VoidProcessor{}
}
// GetHandler returns nil for void processor since it doesn't maintain handlers
func (v *VoidProcessor) GetHandler(_ string) NotificationHandler {
return nil
}
// RegisterHandler returns an error for void processor since it doesn't maintain handlers
func (v *VoidProcessor) RegisterHandler(pushNotificationName string, _ NotificationHandler, _ bool) error {
return ErrVoidProcessorRegister(pushNotificationName)
}
// UnregisterHandler returns an error for void processor since it doesn't maintain handlers
func (v *VoidProcessor) UnregisterHandler(pushNotificationName string) error {
return ErrVoidProcessorUnregister(pushNotificationName)
}
// ProcessPendingNotifications for VoidProcessor does nothing since push notifications
// are only available in RESP3 and this processor is used for RESP2 connections.
// This avoids unnecessary buffer scanning overhead.
// It does however read and discard all push notifications from the buffer to avoid
// them being interpreted as a reply.
// This method should be called by the client in WithReader before reading the reply
// to be sure there are no buffered push notifications.
// It will try to read from the socket and if it is empty - it may block.
func (v *VoidProcessor) ProcessPendingNotifications(_ context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
// read and discard all push notifications
if rd == nil {
return nil
}
for {
// Check if there's data available to read
replyType, err := rd.PeekReplyType()
if err != nil {
// No more data available or error reading
// if timeout, it will be handled by the caller
break
}
// Only process push notifications (arrays starting with >)
if replyType != proto.RespPush {
break
}
// see if we should skip this notification
notificationName, err := rd.PeekPushNotificationName()
if err != nil {
break
}
if willHandleNotificationInClient(notificationName) {
break
}
// Read the push notification
_, err = rd.ReadReply()
if err != nil {
internal.Logger.Printf(context.Background(), "push: error reading push notification: %v", err)
return nil
}
}
return nil
}
// willHandleNotificationInClient checks if a notification type should be ignored by the push notification
// processor and handled by other specialized systems instead (pub/sub, streams, keyspace, etc.).
func willHandleNotificationInClient(notificationType string) bool {
switch notificationType {
// Pub/Sub notifications - handled by pub/sub system
case "message", // Regular pub/sub message
"pmessage", // Pattern pub/sub message
"subscribe", // Subscription confirmation
"unsubscribe", // Unsubscription confirmation
"psubscribe", // Pattern subscription confirmation
"punsubscribe", // Pattern unsubscription confirmation
"smessage", // Sharded pub/sub message (Redis 7.0+)
"ssubscribe", // Sharded subscription confirmation
"sunsubscribe": // Sharded unsubscription confirmation
return true
default:
return false
}
}
+7
View File
@@ -0,0 +1,7 @@
// Package push provides push notifications for Redis.
// This is an EXPERIMENTAL API for handling push notifications from Redis.
// It is not yet stable and may change in the future.
// Although this is in a public package, in its current form public use is not advised.
// Pending push notifications should be processed before executing any readReply from the connection
// as per RESP3 specification push notifications can be sent at any time.
package push
+61
View File
@@ -0,0 +1,61 @@
package push
import (
"sync"
)
// Registry manages push notification handlers
type Registry struct {
mu sync.RWMutex
handlers map[string]NotificationHandler
protected map[string]bool
}
// NewRegistry creates a new push notification registry
func NewRegistry() *Registry {
return &Registry{
handlers: make(map[string]NotificationHandler),
protected: make(map[string]bool),
}
}
// RegisterHandler registers a handler for a specific push notification name
func (r *Registry) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
if handler == nil {
return ErrHandlerNil
}
r.mu.Lock()
defer r.mu.Unlock()
// Check if handler already exists
if _, exists := r.protected[pushNotificationName]; exists {
return ErrHandlerExists(pushNotificationName)
}
r.handlers[pushNotificationName] = handler
r.protected[pushNotificationName] = protected
return nil
}
// GetHandler returns the handler for a specific push notification name
func (r *Registry) GetHandler(pushNotificationName string) NotificationHandler {
r.mu.RLock()
defer r.mu.RUnlock()
return r.handlers[pushNotificationName]
}
// UnregisterHandler removes a handler for a specific push notification name
func (r *Registry) UnregisterHandler(pushNotificationName string) error {
r.mu.Lock()
defer r.mu.Unlock()
// Check if handler is protected
if protected, exists := r.protected[pushNotificationName]; exists && protected {
return ErrProtectedHandler(pushNotificationName)
}
delete(r.handlers, pushNotificationName)
delete(r.protected, pushNotificationName)
return nil
}
+21
View File
@@ -0,0 +1,21 @@
package redis
import (
"github.com/redis/go-redis/v9/push"
)
// NewPushNotificationProcessor creates a new push notification processor
// This processor maintains a registry of handlers and processes push notifications
// It is used for RESP3 connections where push notifications are available
func NewPushNotificationProcessor() push.NotificationProcessor {
return push.NewProcessor()
}
// NewVoidPushNotificationProcessor creates a new void push notification processor
// This processor does not maintain any handlers and always returns nil for all operations
// It is used for RESP2 connections where push notifications are not available
// It can also be used to disable push notifications for RESP3 connections, where
// it will discard all push notifications without processing them
func NewVoidPushNotificationProcessor() push.NotificationProcessor {
return push.NewVoidProcessor()
}
+557 -77
View File
@@ -11,9 +11,12 @@ import (
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/auth/streaming"
"github.com/redis/go-redis/v9/internal/hscan"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
// Scanner internal/hscan.Scanner exposed interface.
@@ -23,10 +26,16 @@ type Scanner = hscan.Scanner
const Nil = proto.Nil
// SetLogger set custom log
// Use with VoidLogger to disable logging.
func SetLogger(logger internal.Logging) {
internal.Logger = logger
}
// SetLogLevel sets the log level for the library.
func SetLogLevel(logLevel internal.LogLevelT) {
internal.LogLevel = logLevel
}
//------------------------------------------------------------------------------
type Hook interface {
@@ -202,16 +211,39 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
//------------------------------------------------------------------------------
type baseClient struct {
opt *Options
connPool pool.Pooler
opt *Options
optLock sync.RWMutex
connPool pool.Pooler
pubSubPool *pool.PubSubPool
hooksMixin
onClose func() error // hook called when client is closed
// Push notification processing
pushProcessor push.NotificationProcessor
// Maintenance notifications manager
maintNotificationsManager *maintnotifications.Manager
maintNotificationsManagerLock sync.RWMutex
// streamingCredentialsManager is used to manage streaming credentials
streamingCredentialsManager *streaming.Manager
}
func (c *baseClient) clone() *baseClient {
clone := *c
return &clone
c.maintNotificationsManagerLock.RLock()
maintNotificationsManager := c.maintNotificationsManager
c.maintNotificationsManagerLock.RUnlock()
clone := &baseClient{
opt: c.opt,
connPool: c.connPool,
onClose: c.onClose,
pushProcessor: c.pushProcessor,
maintNotificationsManager: maintNotificationsManager,
streamingCredentialsManager: c.streamingCredentialsManager,
}
return clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
@@ -229,21 +261,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
cn, err := c.connPool.NewConn(ctx)
if err != nil {
return nil, err
}
err = c.initConn(ctx, cn)
if err != nil {
_ = c.connPool.CloseConn(cn)
return nil, err
}
return cn, nil
}
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
@@ -269,7 +286,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
if cn.Inited {
if cn.IsInited() {
return cn, nil
}
@@ -281,35 +298,39 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
return nil, err
}
// initConn will transition to IDLE state, so we need to acquire it
// before returning it to the user.
if !cn.TryAcquire() {
return nil, fmt.Errorf("redis: connection is not usable")
}
return cn, nil
}
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
return auth.NewReAuthCredentialsListener(
c.reAuthConnection(poolCn),
c.onAuthenticationErr(poolCn),
)
}
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
return func(credentials auth.Credentials) error {
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
var err error
username, password := credentials.BasicAuth()
// Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter
ctx := context.Background()
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
// hooksMixin are intentionally empty here
cn := newConn(c.opt, connPool, nil)
// Pass hooks so that reauth commands are recorded/traced
cn := newConn(c.opt, connPool, &c.hooksMixin)
if username != "" {
err = cn.AuthACL(ctx, username, password).Err()
} else {
err = cn.Auth(ctx, password).Err()
}
return err
}
}
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
return func(err error) {
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
return func(poolCn *pool.Conn, err error) {
if err != nil {
if isBadConn(err, false, c.opt.Addr) {
// Close the connection to force a reconnection.
@@ -351,29 +372,113 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
}
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
if cn.Inited {
// This function is called in two scenarios:
// 1. First-time init: Connection is in CREATED state (from pool.Get())
// - We need to transition CREATED → INITIALIZING and do the initialization
// - If another goroutine is already initializing, we WAIT for it to finish
// 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn())
// - We're already in INITIALIZING, so just proceed with initialization
currentState := cn.GetStateMachine().GetState()
// Fast path: Check if already initialized (IDLE or IN_USE)
if currentState == pool.StateIdle || currentState == pool.StateInUse {
return nil
}
var err error
cn.Inited = true
// If in CREATED state, try to transition to INITIALIZING
if currentState == pool.StateCreated {
finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing)
if err != nil {
// Another goroutine is initializing or connection is in unexpected state
// Check what state we're in now
if finalState == pool.StateIdle || finalState == pool.StateInUse {
// Already initialized by another goroutine
return nil
}
if finalState == pool.StateInitializing {
// Another goroutine is initializing - WAIT for it to complete
// Use a context with timeout = min(remaining command timeout, DialTimeout)
// This prevents waiting too long while respecting the caller's deadline
var waitCtx context.Context
var cancel context.CancelFunc
dialTimeout := c.opt.DialTimeout
if cmdDeadline, hasCmdDeadline := ctx.Deadline(); hasCmdDeadline {
// Calculate remaining time until command deadline
remainingTime := time.Until(cmdDeadline)
// Use the minimum of remaining time and DialTimeout
if remainingTime < dialTimeout {
// Command deadline is sooner, use it
waitCtx = ctx
} else {
// DialTimeout is shorter, cap the wait at DialTimeout
waitCtx, cancel = context.WithTimeout(ctx, dialTimeout)
}
} else {
// No command deadline, use DialTimeout to prevent waiting indefinitely
waitCtx, cancel = context.WithTimeout(ctx, dialTimeout)
}
if cancel != nil {
defer cancel()
}
finalState, err := cn.GetStateMachine().AwaitAndTransition(
waitCtx,
[]pool.ConnState{pool.StateIdle, pool.StateInUse},
pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op)
)
if err != nil {
return err
}
// Verify we're now initialized
if finalState == pool.StateIdle || finalState == pool.StateInUse {
return nil
}
// Unexpected state after waiting
return fmt.Errorf("connection in unexpected state after initialization: %s", finalState)
}
// Unexpected state (CLOSED, UNUSABLE, etc.)
return err
}
}
// At this point, we're in INITIALIZING state and we own the initialization
// If we fail, we must transition to CLOSED
var initErr error
connPool := pool.NewSingleConnPool(c.connPool, cn)
conn := newConn(c.opt, connPool, &c.hooksMixin)
username, password := "", ""
if c.opt.StreamingCredentialsProvider != nil {
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
Subscribe(c.newReAuthCredentialsListener(cn))
if err != nil {
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
credListener, initErr := c.streamingCredentialsManager.Listener(
cn,
c.reAuthConnection(),
c.onAuthenticationErr(),
)
if initErr != nil {
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to create credentials listener: %w", initErr)
}
credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider.
Subscribe(credListener)
if initErr != nil {
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr)
}
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
cn.SetOnClose(unsubscribeFromCredentialsProvider)
username, password = credentials.BasicAuth()
} else if c.opt.CredentialsProviderContext != nil {
username, password, err = c.opt.CredentialsProviderContext(ctx)
if err != nil {
return fmt.Errorf("failed to get credentials from context provider: %w", err)
username, password, initErr = c.opt.CredentialsProviderContext(ctx)
if initErr != nil {
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to get credentials from context provider: %w", initErr)
}
} else if c.opt.CredentialsProvider != nil {
username, password = c.opt.CredentialsProvider()
@@ -383,9 +488,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// for redis-server versions that do not support the HELLO command,
// RESP2 will continue to be used.
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil {
// Authentication successful with HELLO command
} else if !isRedisError(err) {
} else if !isRedisError(initErr) {
// When the server responds with the RESP protocol and the result is not a normal
// execution result of the HELLO command, we consider it to be an indication that
// the server does not support the HELLO command.
@@ -393,20 +498,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
// with different error string results for unsupported commands, making it
// difficult to rely on error strings to determine all results.
return err
cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
} else if password != "" {
// Try legacy AUTH command if HELLO failed
if username != "" {
err = conn.AuthACL(ctx, username, password).Err()
initErr = conn.AuthACL(ctx, username, password).Err()
} else {
err = conn.Auth(ctx, password).Err()
initErr = conn.Auth(ctx, password).Err()
}
if err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
if initErr != nil {
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to authenticate: %w", initErr)
}
}
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
_, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error {
if c.opt.DB > 0 {
pipe.Select(ctx, c.opt.DB)
}
@@ -421,8 +528,58 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
return nil
})
if err != nil {
return fmt.Errorf("failed to initialize connection options: %w", err)
if initErr != nil {
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to initialize connection options: %w", initErr)
}
// Enable maintnotifications if maintnotifications are configured
c.optLock.RLock()
maintNotifEnabled := c.opt.MaintNotificationsConfig != nil && c.opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled
protocol := c.opt.Protocol
endpointType := c.opt.MaintNotificationsConfig.EndpointType
c.optLock.RUnlock()
var maintNotifHandshakeErr error
if maintNotifEnabled && protocol == 3 {
maintNotifHandshakeErr = conn.ClientMaintNotifications(
ctx,
true,
endpointType.String(),
).Err()
if maintNotifHandshakeErr != nil {
if !isRedisError(maintNotifHandshakeErr) {
// if not redis error, fail the connection
cn.GetStateMachine().Transition(pool.StateClosed)
return maintNotifHandshakeErr
}
c.optLock.Lock()
// handshake failed - check and modify config atomically
switch c.opt.MaintNotificationsConfig.Mode {
case maintnotifications.ModeEnabled:
// enabled mode, fail the connection
c.optLock.Unlock()
cn.GetStateMachine().Transition(pool.StateClosed)
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
default: // will handle auto and any other
// Disabling logging here as it's too noisy.
// TODO: Enable when we have a better logging solution for log levels
// internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
c.optLock.Unlock()
// auto mode, disable maintnotifications and continue
if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil {
// Log error but continue - auto mode should be resilient
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr)
}
}
} else {
// handshake was executed successfully
// to make sure that the handshake will be executed on other connections as well if it was successfully
// executed on this connection, we will force the handshake to be executed on all connections
c.optLock.Lock()
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeEnabled
c.optLock.Unlock()
}
}
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
@@ -436,13 +593,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
// out of order responses later on.
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
return err
if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) {
cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
}
}
// Set the connection initialization function for potential reconnections
// This must be set before transitioning to IDLE so that handoff/reauth can use it
cn.SetInitConnFunc(c.createInitConnFunc())
// Initialization succeeded - transition to IDLE state
// This marks the connection as initialized and ready for use
// NOTE: The connection is still owned by the calling goroutine at this point
// and won't be available to other goroutines until it's Put() back into the pool
cn.GetStateMachine().Transition(pool.StateIdle)
// Call OnConnect hook if configured
// The connection is in IDLE state but still owned by this goroutine
// If OnConnect needs to send commands, it can use the connection safely
if c.opt.OnConnect != nil {
return c.opt.OnConnect(ctx, conn)
if initErr = c.opt.OnConnect(ctx, conn); initErr != nil {
// OnConnect failed - transition to closed
cn.GetStateMachine().Transition(pool.StateClosed)
return initErr
}
}
return nil
@@ -456,6 +631,10 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
if isBadConn(err, false, c.opt.Addr) {
c.connPool.Remove(ctx, cn, err)
} else {
// process any pending push notifications before returning the connection to the pool
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
}
c.connPool.Put(ctx, cn)
}
}
@@ -497,16 +676,16 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return lastErr
}
func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
func (c *baseClient) assertUnstableCommand(cmd Cmder) (bool, error) {
switch cmd.(type) {
case *AggregateCmd, *FTInfoCmd, *FTSpellCheckCmd, *FTSearchCmd, *FTSynDumpCmd:
if c.opt.UnstableResp3 {
return true
return true, nil
} else {
panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
return false, fmt.Errorf("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3. See the README and the release notes for guidance")
}
default:
return false
return false, nil
}
}
@@ -519,6 +698,11 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
retryTimeout := uint32(0)
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the command
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
}); err != nil {
@@ -527,10 +711,22 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
}
readReplyFunc := cmd.readReply
// Apply unstable RESP3 search module.
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
readReplyFunc = cmd.readRawReply
if c.opt.Protocol != 2 {
useRawReply, err := c.assertUnstableCommand(cmd)
if err != nil {
return err
}
if useRawReply {
readReplyFunc = cmd.readRawReply
}
}
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil {
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
return readReplyFunc(rd)
}); err != nil {
if cmd.readTimeout() == nil {
atomic.StoreUint32(&retryTimeout, 1)
} else {
@@ -573,19 +769,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
return context.Background()
}
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
return func(ctx context.Context, cn *pool.Conn) error {
return c.initConn(ctx, cn)
}
}
// enableMaintNotificationsUpgrades initializes the maintnotifications upgrade manager and pool hook.
// This function is called during client initialization.
// will register push notification handlers for all maintenance upgrade events.
// will start background workers for handoff processing in the pool hook.
func (c *baseClient) enableMaintNotificationsUpgrades() error {
// Create client adapter
clientAdapterInstance := newClientAdapter(c)
// Create maintnotifications manager directly
manager, err := maintnotifications.NewManager(clientAdapterInstance, c.connPool, c.opt.MaintNotificationsConfig)
if err != nil {
return err
}
// Set the manager reference and initialize pool hook
c.maintNotificationsManagerLock.Lock()
c.maintNotificationsManager = manager
c.maintNotificationsManagerLock.Unlock()
// Initialize pool hook (safe to call without lock since manager is now set)
manager.InitPoolHook(c.dialHook)
return nil
}
func (c *baseClient) disableMaintNotificationsUpgrades() error {
c.maintNotificationsManagerLock.Lock()
defer c.maintNotificationsManagerLock.Unlock()
// Close the maintnotifications manager
if c.maintNotificationsManager != nil {
// Closing the manager will also shutdown the pool hook
// and remove it from the pool
c.maintNotificationsManager.Close()
c.maintNotificationsManager = nil
}
return nil
}
// Close closes the client, releasing any open resources.
//
// It is rare to Close a Client, as the Client is meant to be
// long-lived and shared between many goroutines.
func (c *baseClient) Close() error {
var firstErr error
// Close maintnotifications manager first
if err := c.disableMaintNotificationsUpgrades(); err != nil {
firstErr = err
}
if c.onClose != nil {
if err := c.onClose(); err != nil {
if err := c.onClose(); err != nil && firstErr == nil {
firstErr = err
}
}
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
if c.connPool != nil {
if err := c.connPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
if c.pubSubPool != nil {
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
@@ -625,12 +878,19 @@ func (c *baseClient) generalProcessPipeline(
// Enable retries by default to retry dial errors returned by withConn.
canRetry := true
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
// Process any pending push notifications before executing the pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err)
}
var err error
canRetry, err = p(ctx, cn, cmds)
return err
})
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
setCmdsErr(cmds, lastErr)
// The error should be set here only when failing to obtain the conn.
if !isRedisError(lastErr) {
setCmdsErr(cmds, lastErr)
}
return lastErr
}
}
@@ -640,6 +900,11 @@ func (c *baseClient) generalProcessPipeline(
func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
// Process any pending push notifications before executing the pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
@@ -648,7 +913,8 @@ func (c *baseClient) pipelineProcessCmds(
}
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds)
// read all replies
return c.pipelineReadCmds(ctx, cn, rd, cmds)
}); err != nil {
return true, err
}
@@ -656,8 +922,12 @@ func (c *baseClient) pipelineProcessCmds(
return false, nil
}
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *proto.Reader, cmds []Cmder) error {
for i, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
err := cmd.readReply(rd)
cmd.SetErr(err)
if err != nil && !isRedisError(err) {
@@ -672,6 +942,11 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
// Process any pending push notifications before executing the transaction pipeline
if err := c.processPushNotifications(ctx, cn); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
}
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
}); err != nil {
@@ -684,12 +959,13 @@ func (c *baseClient) txPipelineProcessCmds(
// Trim multi and exec.
trimmedCmds := cmds[1 : len(cmds)-1]
if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil {
if err := c.txPipelineReadQueued(ctx, cn, rd, statusCmd, trimmedCmds); err != nil {
setCmdsErr(cmds, err)
return err
}
return pipelineReadCmds(rd, trimmedCmds)
// Read replies.
return c.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
}); err != nil {
return false, err
}
@@ -697,7 +973,13 @@ func (c *baseClient) txPipelineProcessCmds(
return false, nil
}
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// txPipelineReadQueued reads queued replies from the Redis server.
// It returns an error if the server returns an error or if the number of replies does not match the number of commands.
func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse +OK.
if err := statusCmd.readReply(rd); err != nil {
return err
@@ -705,6 +987,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
// Parse +QUEUED.
for _, cmd := range cmds {
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
if err := statusCmd.readReply(rd); err != nil {
cmd.SetErr(err)
if !isRedisError(err) {
@@ -713,6 +999,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
}
}
// To be sure there are no buffered push notifications, we process them before reading the reply
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
}
// Parse number of replies.
line, err := rd.ReadLine()
if err != nil {
@@ -746,15 +1036,61 @@ func NewClient(opt *Options) *Client {
if opt == nil {
panic("redis: NewClient nil options")
}
// clone to not share options with the caller
opt = opt.clone()
opt.init()
// Push notifications are always enabled for RESP3 (cannot be disabled)
c := Client{
baseClient: &baseClient{
opt: opt,
},
}
c.init()
c.connPool = newConnPool(opt, c.dialHook)
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
// set opt push processor for child clients
c.opt.PushNotificationProcessor = c.pushProcessor
// Create connection pools
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
if opt.StreamingCredentialsProvider != nil {
c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout)
c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook())
}
// Initialize maintnotifications first if enabled and protocol is RESP3
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {
err := c.enableMaintNotificationsUpgrades()
if err != nil {
internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err)
if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled {
/*
Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested.
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
Clients relying on maintnotifications should be aware that initialization errors will cause a panic, and should
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
This approach is only used when MaintNotificationsConfig.Mode is MaintNotificationsEnabled, indicating that maintnotifications
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
*/
panic(fmt.Errorf("failed to enable maintnotifications: %w", err))
}
}
}
return &c
}
@@ -791,11 +1127,51 @@ func (c *Client) Options() *Options {
return c.opt
}
// GetMaintNotificationsManager returns the maintnotifications manager instance for monitoring and control.
// Returns nil if maintnotifications are not enabled.
func (c *Client) GetMaintNotificationsManager() *maintnotifications.Manager {
c.maintNotificationsManagerLock.RLock()
defer c.maintNotificationsManagerLock.RUnlock()
return c.maintNotificationsManager
}
// initializePushProcessor initializes the push notification processor for any client type.
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
func initializePushProcessor(opt *Options) push.NotificationProcessor {
// Always use custom processor if provided
if opt.PushNotificationProcessor != nil {
return opt.PushNotificationProcessor
}
// Push notifications are always enabled for RESP3, disabled for RESP2
if opt.Protocol == 3 {
// Create default processor for RESP3 connections
return NewPushNotificationProcessor()
}
// Create void processor for RESP2 connections (push notifications not available)
return NewVoidPushNotificationProcessor()
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
// GetPushNotificationHandler returns the handler for a specific push notification name.
// Returns nil if no handler is registered for the given name.
func (c *Client) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
return c.pushProcessor.GetHandler(pushNotificationName)
}
type PoolStats pool.Stats
// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
stats := c.connPool.Stats()
stats.PubSubStats = *(c.pubSubPool.Stats())
return (*PoolStats)(stats)
}
@@ -830,13 +1206,31 @@ func (c *Client) TxPipeline() Pipeliner {
func (c *Client) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}
@@ -920,6 +1314,10 @@ func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn
c.hooksMixin = parentHooks.clone()
}
// Initialize push notification processor using shared helper
// Use void processor for RESP2 connections (push notifications not available)
c.pushProcessor = initializePushProcessor(opt)
c.cmdable = c.Process
c.statefulCmdable = c.Process
c.initHooks(hooks{
@@ -938,6 +1336,13 @@ func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
return err
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *Conn) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(ctx, fn)
}
@@ -965,3 +1370,78 @@ func (c *Conn) TxPipeline() Pipeliner {
pipe.init()
return &pipe
}
// processPushNotifications processes all pending push notifications on a connection
// This ensures that cluster topology changes are handled immediately before the connection is used
// This method should be called by the client before using WithReader for command execution
//
// Performance optimization: Skip the expensive MaybeHasData() syscall if a health check
// was performed recently (within 5 seconds). The health check already verified the connection
// is healthy and checked for unexpected data (push notifications).
func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn) error {
// Only process push notifications for RESP3 connections with a processor
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}
// Performance optimization: Skip MaybeHasData() syscall if health check was recent
// If the connection was health-checked within the last 5 seconds, we can skip the
// expensive syscall since the health check already verified no unexpected data.
// This is safe because:
// 0. lastHealthCheckNs is set in pool/conn.go:putConn() after a successful health check
// 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK)
// 2. If push notifications arrived, they would have been detected by health check
// 3. 5 seconds is short enough that connection state is still fresh
// 4. Push notifications will be processed by the next WithReader call
// used it is set on getConn, so we should use another timer (lastPutAt?)
lastHealthCheckNs := cn.LastPutAtNs()
if lastHealthCheckNs > 0 {
// Use pool's cached time to avoid expensive time.Now() syscall
nowNs := pool.GetCachedTimeNs()
if nowNs-lastHealthCheckNs < int64(5*time.Second) {
// Recent health check confirmed no unexpected data, skip the syscall
return nil
}
}
// Check if there is any data to read before processing
// This is an optimization on UNIX systems where MaybeHasData is a syscall
// On Windows, MaybeHasData always returns true, so this check is a no-op
if !cn.MaybeHasData() {
return nil
}
// Use WithReader to access the reader and process push notifications
// This is critical for maintnotifications to work properly
// NOTE: almost no timeouts are set for this read, so it should not block
// longer than necessary, 10us should be plenty of time to read if there are any push notifications
// on the socket.
return cn.WithReader(ctx, 10*time.Microsecond, func(rd *proto.Reader) error {
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
})
}
// processPendingPushNotificationWithReader processes all pending push notifications on a connection
// This method should be called by the client in WithReader before reading the reply
func (c *baseClient) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
// if we have the reader, we don't need to check for data on the socket, we are waiting
// for either a reply or a push notification, so we can block until we get a reply or reach the timeout
if c.opt.Protocol != 3 || c.pushProcessor == nil {
return nil
}
// Create handler context with client, connection pool, and connection information
handlerCtx := c.pushNotificationHandlerContext(cn)
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
}
// pushNotificationHandlerContext creates a handler context for push notification processing
func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
return push.NotificationHandlerContext{
Client: c,
ConnPool: c.connPool,
Conn: cn, // Wrap in adapter for easier interface access
}
}
+541 -6
View File
@@ -29,6 +29,8 @@ type SearchCmdable interface {
FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd
FTExplain(ctx context.Context, index string, query string) *StringCmd
FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd
FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd
FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd
FTInfo(ctx context.Context, index string) *FTInfoCmd
FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd
FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd
@@ -344,6 +346,92 @@ type FTSearchOptions struct {
DialectVersion int
}
// FTHybridCombineMethod represents the fusion method for combining search and vector results
type FTHybridCombineMethod string
const (
FTHybridCombineRRF FTHybridCombineMethod = "RRF"
FTHybridCombineLinear FTHybridCombineMethod = "LINEAR"
FTHybridCombineFunction FTHybridCombineMethod = "FUNCTION"
)
// FTHybridSearchExpression represents a search expression in hybrid search
type FTHybridSearchExpression struct {
Query string
Scorer string
ScorerParams []interface{}
YieldScoreAs string
}
type FTHybridVectorMethod = string
const (
KNN FTHybridCombineMethod = "KNN"
RANGE FTHybridCombineMethod = "RANGE"
)
// FTHybridVectorExpression represents a vector expression in hybrid search
type FTHybridVectorExpression struct {
VectorField string
VectorData Vector
Method FTHybridVectorMethod
MethodParams []interface{}
Filter string
YieldScoreAs string
}
// FTHybridCombineOptions represents options for result fusion
type FTHybridCombineOptions struct {
Method FTHybridCombineMethod
Count int
Window int // For RRF
Constant float64 // For RRF
Alpha float64 // For LINEAR
Beta float64 // For LINEAR
YieldScoreAs string
}
// FTHybridGroupBy represents GROUP BY functionality
type FTHybridGroupBy struct {
Count int
Fields []string
ReduceFunc string
ReduceCount int
ReduceParams []interface{}
}
// FTHybridApply represents APPLY functionality
type FTHybridApply struct {
Expression string
AsField string
}
// FTHybridWithCursor represents cursor configuration for hybrid search
type FTHybridWithCursor struct {
Count int // Number of results to return per cursor read
MaxIdle int // Maximum idle time in milliseconds before cursor is automatically deleted
}
// FTHybridOptions hold options that can be passed to the FT.HYBRID command
type FTHybridOptions struct {
CountExpressions int // Number of search/vector expressions
SearchExpressions []FTHybridSearchExpression // Multiple search expressions
VectorExpressions []FTHybridVectorExpression // Multiple vector expressions
Combine *FTHybridCombineOptions // Fusion step options
Load []string // Projected fields
GroupBy *FTHybridGroupBy // Aggregation grouping
Apply []FTHybridApply // Field transformations
SortBy []FTSearchSortBy // Reuse from FTSearch
Filter string // Post-filter expression
LimitOffset int // Result limiting
Limit int
Params map[string]interface{} // Parameter substitution
ExplainScore bool // Include score explanations
Timeout int // Runtime timeout
WithCursor bool // Enable cursor support for large result sets
WithCursorOptions *FTHybridWithCursor // Cursor configuration options
}
type FTSynDumpResult struct {
Term string
Synonyms []string
@@ -423,6 +511,14 @@ type FTAttribute struct {
PhoneticMatcher string
CaseSensitive bool
WithSuffixtrie bool
// Vector specific attributes
Algorithm string
DataType string
Dim int
DistanceMetric string
M int
EFConstruction int
}
type CursorStats struct {
@@ -1296,21 +1392,26 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
for _, attr := range attributes {
if attrMap, ok := attr.([]interface{}); ok {
att := FTAttribute{}
for i := 0; i < len(attrMap); i++ {
if internal.ToLower(internal.ToString(attrMap[i])) == "attribute" {
attrLen := len(attrMap)
for i := 0; i < attrLen; i++ {
if internal.ToLower(internal.ToString(attrMap[i])) == "attribute" && i+1 < attrLen {
att.Attribute = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "identifier" {
if internal.ToLower(internal.ToString(attrMap[i])) == "identifier" && i+1 < attrLen {
att.Identifier = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "type" {
if internal.ToLower(internal.ToString(attrMap[i])) == "type" && i+1 < attrLen {
att.Type = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "weight" {
if internal.ToLower(internal.ToString(attrMap[i])) == "weight" && i+1 < attrLen {
att.Weight = internal.ToFloat(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "nostem" {
@@ -1329,7 +1430,7 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
att.UNF = true
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "phonetic" {
if internal.ToLower(internal.ToString(attrMap[i])) == "phonetic" && i+1 < attrLen {
att.PhoneticMatcher = internal.ToString(attrMap[i+1])
continue
}
@@ -1342,6 +1443,38 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
continue
}
// vector specific attributes
if internal.ToLower(internal.ToString(attrMap[i])) == "algorithm" && i+1 < attrLen {
att.Algorithm = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "data_type" && i+1 < attrLen {
att.DataType = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "dim" && i+1 < attrLen {
att.Dim = internal.ToInteger(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "distance_metric" && i+1 < attrLen {
att.DistanceMetric = internal.ToString(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "m" && i+1 < attrLen {
att.M = internal.ToInteger(attrMap[i+1])
i++
continue
}
if internal.ToLower(internal.ToString(attrMap[i])) == "ef_construction" && i+1 < attrLen {
att.EFConstruction = internal.ToInteger(attrMap[i+1])
i++
continue
}
}
ftInfo.Attributes = append(ftInfo.Attributes, att)
}
@@ -1819,6 +1952,207 @@ func (cmd *FTSearchCmd) readReply(rd *proto.Reader) (err error) {
return nil
}
// FTHybridResult represents the result of a hybrid search operation
type FTHybridResult struct {
TotalResults int
Results []map[string]interface{}
Warnings []string
ExecutionTime float64
}
// FTHybridCursorResult represents cursor result for hybrid search
type FTHybridCursorResult struct {
SearchCursorID int
VsimCursorID int
}
type FTHybridCmd struct {
baseCmd
val FTHybridResult
cursorVal *FTHybridCursorResult
options *FTHybridOptions
withCursor bool
}
func newFTHybridCmd(ctx context.Context, options *FTHybridOptions, args ...interface{}) *FTHybridCmd {
var withCursor bool
if options != nil && options.WithCursor {
withCursor = true
}
return &FTHybridCmd{
baseCmd: baseCmd{
ctx: ctx,
args: args,
},
options: options,
withCursor: withCursor,
}
}
func (cmd *FTHybridCmd) String() string {
return cmdString(cmd, cmd.val)
}
func (cmd *FTHybridCmd) SetVal(val FTHybridResult) {
cmd.val = val
}
func (cmd *FTHybridCmd) Result() (FTHybridResult, error) {
return cmd.val, cmd.err
}
func (cmd *FTHybridCmd) CursorResult() (*FTHybridCursorResult, error) {
return cmd.cursorVal, cmd.err
}
func (cmd *FTHybridCmd) Val() FTHybridResult {
return cmd.val
}
func (cmd *FTHybridCmd) CursorVal() *FTHybridCursorResult {
return cmd.cursorVal
}
func (cmd *FTHybridCmd) RawVal() interface{} {
return cmd.rawVal
}
func (cmd *FTHybridCmd) RawResult() (interface{}, error) {
return cmd.rawVal, cmd.err
}
func parseFTHybrid(data []interface{}, withCursor bool) (FTHybridResult, *FTHybridCursorResult, error) {
// Convert to map
resultMap := make(map[string]interface{})
for i := 0; i < len(data); i += 2 {
if i+1 < len(data) {
key, ok := data[i].(string)
if !ok {
return FTHybridResult{}, nil, fmt.Errorf("invalid key type at index %d", i)
}
resultMap[key] = data[i+1]
}
}
// Handle cursor result
if withCursor {
searchCursorID, ok1 := resultMap["SEARCH"].(int64)
vsimCursorID, ok2 := resultMap["VSIM"].(int64)
if !ok1 || !ok2 {
return FTHybridResult{}, nil, fmt.Errorf("invalid cursor result format")
}
return FTHybridResult{}, &FTHybridCursorResult{
SearchCursorID: int(searchCursorID),
VsimCursorID: int(vsimCursorID),
}, nil
}
// Parse regular result
totalResults, ok := resultMap["total_results"].(int64)
if !ok {
return FTHybridResult{}, nil, fmt.Errorf("invalid total_results format")
}
resultsData, ok := resultMap["results"].([]interface{})
if !ok {
return FTHybridResult{}, nil, fmt.Errorf("invalid results format")
}
// Parse each result item
results := make([]map[string]interface{}, 0, len(resultsData))
for _, item := range resultsData {
// Try parsing as map[string]interface{} first (RESP3 format)
if itemMap, ok := item.(map[string]interface{}); ok {
results = append(results, itemMap)
continue
}
// Try parsing as map[interface{}]interface{} (alternative RESP3 format)
if rawMap, ok := item.(map[interface{}]interface{}); ok {
itemMap := make(map[string]interface{})
for k, v := range rawMap {
if keyStr, ok := k.(string); ok {
itemMap[keyStr] = v
}
}
results = append(results, itemMap)
continue
}
// Fall back to array format (RESP2 format - key-value pairs)
itemData, ok := item.([]interface{})
if !ok {
return FTHybridResult{}, nil, fmt.Errorf("invalid result item format")
}
itemMap := make(map[string]interface{})
for i := 0; i < len(itemData); i += 2 {
if i+1 < len(itemData) {
key, ok := itemData[i].(string)
if !ok {
return FTHybridResult{}, nil, fmt.Errorf("invalid item key format")
}
itemMap[key] = itemData[i+1]
}
}
results = append(results, itemMap)
}
// Parse warnings (optional field)
var warnings []string
if warningsData, ok := resultMap["warnings"].([]interface{}); ok {
warnings = make([]string, 0, len(warningsData))
for _, w := range warningsData {
if ws, ok := w.(string); ok {
warnings = append(warnings, ws)
}
}
}
// Parse execution time (optional field)
var executionTime float64
if execTimeVal, exists := resultMap["execution_time"]; exists {
switch v := execTimeVal.(type) {
case string:
var err error
executionTime, err = strconv.ParseFloat(v, 64)
if err != nil {
return FTHybridResult{}, nil, fmt.Errorf("invalid execution_time format: %v", err)
}
case float64:
executionTime = v
case int64:
executionTime = float64(v)
}
}
return FTHybridResult{
TotalResults: int(totalResults),
Results: results,
Warnings: warnings,
ExecutionTime: executionTime,
}, nil, nil
}
func (cmd *FTHybridCmd) readReply(rd *proto.Reader) (err error) {
data, err := rd.ReadSlice()
if err != nil {
return err
}
result, cursorResult, err := parseFTHybrid(data, cmd.withCursor)
if err != nil {
return err
}
if cmd.withCursor {
cmd.cursorVal = cursorResult
} else {
cmd.val = result
}
return nil
}
// FTSearch - Executes a search query on an index.
// The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query.
// For more information, please refer to the Redis documentation about [FT.SEARCH].
@@ -2191,3 +2525,204 @@ func (c cmdable) FTTagVals(ctx context.Context, index string, field string) *Str
_ = c(ctx, cmd)
return cmd
}
// FTHybrid - Executes a hybrid search combining full-text search and vector similarity
// The 'index' parameter specifies the index to search, 'searchExpr' is the search query,
// 'vectorField' is the name of the vector field, and 'vectorData' is the vector to search with.
// FTHybrid is still experimental, the command behaviour and signature may change
func (c cmdable) FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd {
options := &FTHybridOptions{
CountExpressions: 2,
SearchExpressions: []FTHybridSearchExpression{
{Query: searchExpr},
},
VectorExpressions: []FTHybridVectorExpression{
{VectorField: vectorField, VectorData: vectorData},
},
}
return c.FTHybridWithArgs(ctx, index, options)
}
// FTHybridWithArgs - Executes a hybrid search with advanced options
// FTHybridWithArgs is still experimental, the command behaviour and signature may change
func (c cmdable) FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd {
args := []interface{}{"FT.HYBRID", index}
if options != nil {
// Add search expressions
for _, searchExpr := range options.SearchExpressions {
args = append(args, "SEARCH", searchExpr.Query)
if searchExpr.Scorer != "" {
args = append(args, "SCORER", searchExpr.Scorer)
if len(searchExpr.ScorerParams) > 0 {
args = append(args, searchExpr.ScorerParams...)
}
}
if searchExpr.YieldScoreAs != "" {
args = append(args, "YIELD_SCORE_AS", searchExpr.YieldScoreAs)
}
}
// Add vector expressions
for _, vectorExpr := range options.VectorExpressions {
args = append(args, "VSIM", "@"+vectorExpr.VectorField)
// For FT.HYBRID, we need to send just the raw vector bytes, not the Value() format
// Value() returns [format, data] but FT.HYBRID expects just the blob
vectorValue := vectorExpr.VectorData.Value()
if len(vectorValue) >= 2 {
// vectorValue is [format, data, ...] - we only want the data part
args = append(args, vectorValue[1])
} else {
// Fallback for unexpected format
args = append(args, vectorValue...)
}
if vectorExpr.Method != "" {
args = append(args, vectorExpr.Method)
if len(vectorExpr.MethodParams) > 0 {
// MethodParams should be key-value pairs, count them
args = append(args, len(vectorExpr.MethodParams))
args = append(args, vectorExpr.MethodParams...)
}
}
if vectorExpr.Filter != "" {
args = append(args, "FILTER", vectorExpr.Filter)
}
if vectorExpr.YieldScoreAs != "" {
args = append(args, "YIELD_SCORE_AS", vectorExpr.YieldScoreAs)
}
}
// Add combine/fusion options
if options.Combine != nil {
// Build combine parameters
combineParams := []interface{}{}
switch options.Combine.Method {
case FTHybridCombineRRF:
if options.Combine.Window > 0 {
combineParams = append(combineParams, "WINDOW", options.Combine.Window)
}
if options.Combine.Constant > 0 {
combineParams = append(combineParams, "CONSTANT", options.Combine.Constant)
}
case FTHybridCombineLinear:
if options.Combine.Alpha > 0 {
combineParams = append(combineParams, "ALPHA", options.Combine.Alpha)
}
if options.Combine.Beta > 0 {
combineParams = append(combineParams, "BETA", options.Combine.Beta)
}
}
if options.Combine.YieldScoreAs != "" {
combineParams = append(combineParams, "YIELD_SCORE_AS", options.Combine.YieldScoreAs)
}
// Add COMBINE with method and parameter count
args = append(args, "COMBINE", string(options.Combine.Method))
if len(combineParams) > 0 {
args = append(args, len(combineParams))
args = append(args, combineParams...)
}
}
// Add LOAD (projected fields)
if len(options.Load) > 0 {
args = append(args, "LOAD", len(options.Load))
for _, field := range options.Load {
args = append(args, field)
}
}
// Add GROUPBY
if options.GroupBy != nil {
args = append(args, "GROUPBY", options.GroupBy.Count)
for _, field := range options.GroupBy.Fields {
args = append(args, field)
}
if options.GroupBy.ReduceFunc != "" {
args = append(args, "REDUCE", options.GroupBy.ReduceFunc, options.GroupBy.ReduceCount)
args = append(args, options.GroupBy.ReduceParams...)
}
}
// Add APPLY transformations
for _, apply := range options.Apply {
args = append(args, "APPLY", apply.Expression, "AS", apply.AsField)
}
// Add SORTBY
if len(options.SortBy) > 0 {
sortByOptions := []interface{}{}
for _, sortBy := range options.SortBy {
sortByOptions = append(sortByOptions, sortBy.FieldName)
if sortBy.Asc && sortBy.Desc {
cmd := newFTHybridCmd(ctx, options, args...)
cmd.SetErr(fmt.Errorf("FT.HYBRID: ASC and DESC are mutually exclusive"))
return cmd
}
if sortBy.Asc {
sortByOptions = append(sortByOptions, "ASC")
}
if sortBy.Desc {
sortByOptions = append(sortByOptions, "DESC")
}
}
args = append(args, "SORTBY", len(sortByOptions))
args = append(args, sortByOptions...)
}
// Add FILTER (post-filter)
if options.Filter != "" {
args = append(args, "FILTER", options.Filter)
}
// Add LIMIT
if options.LimitOffset >= 0 && options.Limit > 0 || options.LimitOffset > 0 && options.Limit == 0 {
args = append(args, "LIMIT", options.LimitOffset, options.Limit)
}
// Add PARAMS
if len(options.Params) > 0 {
args = append(args, "PARAMS", len(options.Params)*2)
for key, value := range options.Params {
// Parameter keys should already have '$' prefix from the user
// Don't add it again if it's already there
args = append(args, key, value)
}
}
// Add EXPLAINSCORE
if options.ExplainScore {
args = append(args, "EXPLAINSCORE")
}
// Add TIMEOUT
if options.Timeout > 0 {
args = append(args, "TIMEOUT", options.Timeout)
}
// Add WITHCURSOR support
if options.WithCursor {
args = append(args, "WITHCURSOR")
if options.WithCursorOptions != nil {
if options.WithCursorOptions.Count > 0 {
args = append(args, "COUNT", options.WithCursorOptions.Count)
}
if options.WithCursorOptions.MaxIdle > 0 {
args = append(args, "MAXIDLE", options.WithCursorOptions.MaxIdle)
}
}
}
}
cmd := newFTHybridCmd(ctx, options, args...)
_ = c(ctx, cmd)
return cmd
}
+100 -13
View File
@@ -17,6 +17,8 @@ import (
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/rand"
"github.com/redis/go-redis/v9/internal/util"
"github.com/redis/go-redis/v9/maintnotifications"
"github.com/redis/go-redis/v9/push"
)
//------------------------------------------------------------------------------
@@ -62,6 +64,8 @@ type FailoverOptions struct {
Protocol int
Username string
Password string
// Push notifications are always enabled for RESP3 connections
// CredentialsProvider allows the username and password to be updated
// before reconnecting. It should return the current username and password.
CredentialsProvider func() (username string, password string)
@@ -136,6 +140,15 @@ type FailoverOptions struct {
FailingTimeoutSeconds int
UnstableResp3 bool
// MaintNotificationsConfig is not supported for FailoverClients at the moment
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
// upgrade notifications gracefully and manage connection/pool state transitions
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
// If nil, maintnotifications upgrades are disabled.
// (however if Mode is nil, it defaults to "auto" - enable if server supports it)
//MaintNotificationsConfig *maintnotifications.Config
}
func (opt *FailoverOptions) clientOptions() *Options {
@@ -182,6 +195,10 @@ func (opt *FailoverOptions) clientOptions() *Options {
IdentitySuffix: opt.IdentitySuffix,
UnstableResp3: opt.UnstableResp3,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeDisabled,
},
}
}
@@ -226,6 +243,10 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options {
IdentitySuffix: opt.IdentitySuffix,
UnstableResp3: opt.UnstableResp3,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeDisabled,
},
}
}
@@ -275,6 +296,10 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions {
DisableIndentity: opt.DisableIndentity,
IdentitySuffix: opt.IdentitySuffix,
FailingTimeoutSeconds: opt.FailingTimeoutSeconds,
MaintNotificationsConfig: &maintnotifications.Config{
Mode: maintnotifications.ModeDisabled,
},
}
}
@@ -454,8 +479,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
opt.Dialer = masterReplicaDialer(failover)
opt.init()
var connPool *pool.ConnPool
rdb := &Client{
baseClient: &baseClient{
opt: opt,
@@ -463,15 +486,29 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
}
rdb.init()
connPool = newConnPool(opt, rdb.dialHook)
rdb.connPool = connPool
// Initialize push notification processor using shared helper
// Use void processor by default for RESP2 connections
rdb.pushProcessor = initializePushProcessor(opt)
var err error
rdb.connPool, err = newConnPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
rdb.onClose = rdb.wrappedOnClose(failover.Close)
failover.mu.Lock()
failover.onFailover = func(ctx context.Context, addr string) {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
_ = connPool.Filter(func(cn *pool.Conn) bool {
return cn.RemoteAddr().String() != addr
})
}
}
failover.mu.Unlock()
@@ -529,15 +566,40 @@ func NewSentinelClient(opt *Options) *SentinelClient {
},
}
// Initialize push notification processor using shared helper
// Use void processor for Sentinel clients
c.pushProcessor = NewVoidPushNotificationProcessor()
c.initHooks(hooks{
dial: c.baseClient.dial,
process: c.baseClient.process,
})
c.connPool = newConnPool(opt, c.dialHook)
var err error
c.connPool, err = newConnPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
}
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
if err != nil {
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
}
return c
}
// GetPushNotificationHandler returns the handler for a specific push notification name.
// Returns nil if no handler is registered for the given name.
func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
return c.pushProcessor.GetHandler(pushNotificationName)
}
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
// Returns an error if a handler is already registered for this push notification name.
// If protected is true, the handler cannot be unregistered.
func (c *SentinelClient) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
}
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
err := c.processHook(ctx, cmd)
cmd.SetErr(err)
@@ -547,13 +609,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
func (c *SentinelClient) pubSub() *PubSub {
pubsub := &PubSub{
opt: c.opt,
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
return c.newConn(ctx)
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
if err != nil {
return nil, err
}
// will return nil if already initialized
err = c.initConn(ctx, cn)
if err != nil {
_ = cn.Close()
return nil, err
}
// Track connection in PubSubPool
c.pubSubPool.TrackConn(cn)
return cn, nil
},
closeConn: c.connPool.CloseConn,
closeConn: func(cn *pool.Conn) error {
// Untrack connection from PubSubPool
c.pubSubPool.UntrackConn(cn)
_ = cn.Close()
return nil
},
pushProcessor: c.pushProcessor,
}
pubsub.init()
return pubsub
}
@@ -776,6 +856,11 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
}
}
// short circuit if no sentinels configured
if len(c.sentinelAddrs) == 0 {
return "", errors.New("redis: no sentinels configured")
}
var (
masterAddr string
wg sync.WaitGroup
@@ -823,10 +908,12 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
}
func joinErrors(errs []error) string {
if len(errs) == 0 {
return ""
}
if len(errs) == 1 {
return errs[0].Error()
}
b := []byte(errs[0].Error())
for _, err := range errs[1:] {
b = append(b, '\n')
+7 -2
View File
@@ -2,6 +2,7 @@ package redis
import (
"context"
"errors"
"strings"
"time"
@@ -313,7 +314,9 @@ func (c cmdable) ZPopMax(ctx context.Context, key string, count ...int64) *ZSlic
case 1:
args = append(args, count[0])
default:
panic("too many arguments")
cmd := NewZSliceCmd(ctx)
cmd.SetErr(errors.New("too many arguments"))
return cmd
}
cmd := NewZSliceCmd(ctx, args...)
@@ -333,7 +336,9 @@ func (c cmdable) ZPopMin(ctx context.Context, key string, count ...int64) *ZSlic
case 1:
args = append(args, count[0])
default:
panic("too many arguments")
cmd := NewZSliceCmd(ctx)
cmd.SetErr(errors.New("too many arguments"))
return cmd
}
cmd := NewZSliceCmd(ctx, args...)
+5
View File
@@ -263,6 +263,7 @@ type XReadGroupArgs struct {
Count int64
Block time.Duration
NoAck bool
Claim time.Duration // Claim idle pending entries older than this duration
}
func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSliceCmd {
@@ -282,6 +283,10 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic
args = append(args, "noack")
keyPos++
}
if a.Claim > 0 {
args = append(args, "claim", int64(a.Claim/time.Millisecond))
keyPos += 2
}
args = append(args, "streams")
keyPos++
for _, s := range a.Streams {
+446 -1
View File
@@ -2,6 +2,7 @@ package redis
import (
"context"
"fmt"
"time"
)
@@ -9,6 +10,8 @@ type StringCmdable interface {
Append(ctx context.Context, key, value string) *IntCmd
Decr(ctx context.Context, key string) *IntCmd
DecrBy(ctx context.Context, key string, decrement int64) *IntCmd
DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd
Digest(ctx context.Context, key string) *DigestCmd
Get(ctx context.Context, key string) *StringCmd
GetRange(ctx context.Context, key string, start, end int64) *StringCmd
GetSet(ctx context.Context, key string, value interface{}) *StringCmd
@@ -21,9 +24,18 @@ type StringCmdable interface {
MGet(ctx context.Context, keys ...string) *SliceCmd
MSet(ctx context.Context, values ...interface{}) *StatusCmd
MSetNX(ctx context.Context, values ...interface{}) *BoolCmd
MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
SetArgs(ctx context.Context, key string, value interface{}, a SetArgs) *StatusCmd
SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd
SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd
SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd
SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd
SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd
SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd
SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd
SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd
SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
SetXX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd
@@ -48,6 +60,76 @@ func (c cmdable) DecrBy(ctx context.Context, key string, decrement int64) *IntCm
return cmd
}
// DelExArgs provides arguments for the DelExArgs function.
type DelExArgs struct {
// Mode can be `IFEQ`, `IFNE`, `IFDEQ`, or `IFDNE`.
Mode string
// MatchValue is used with IFEQ/IFNE modes for compare-and-delete operations.
// - IFEQ: only delete if current value equals MatchValue
// - IFNE: only delete if current value does not equal MatchValue
MatchValue interface{}
// MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-delete.
// - IFDEQ: only delete if current value's digest equals MatchDigest
// - IFDNE: only delete if current value's digest does not equal MatchDigest
//
// The digest is a uint64 xxh3 hash value.
//
// For examples of client-side digest generation, see:
// example/digest-optimistic-locking/
MatchDigest uint64
}
// DelExArgs Redis `DELEX key [IFEQ|IFNE|IFDEQ|IFDNE] match-value` command.
// Compare-and-delete with flexible conditions.
//
// Returns the number of keys that were removed (0 or 1).
//
// NOTE DelExArgs is still experimental
// it's signature and behaviour may change
func (c cmdable) DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd {
args := []interface{}{"delex", key}
if a.Mode != "" {
args = append(args, a.Mode)
// Add match value/digest based on mode
switch a.Mode {
case "ifeq", "IFEQ", "ifne", "IFNE":
if a.MatchValue != nil {
args = append(args, a.MatchValue)
}
case "ifdeq", "IFDEQ", "ifdne", "IFDNE":
if a.MatchDigest != 0 {
args = append(args, fmt.Sprintf("%016x", a.MatchDigest))
}
}
}
cmd := NewIntCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// Digest returns the xxh3 hash (uint64) of the specified key's value.
//
// The digest is a 64-bit xxh3 hash that can be used for optimistic locking
// with SetIFDEQ, SetIFDNE, and DelExArgs commands.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/digest/
//
// NOTE Digest is still experimental
// it's signature and behaviour may change
func (c cmdable) Digest(ctx context.Context, key string) *DigestCmd {
cmd := NewDigestCmd(ctx, "digest", key)
_ = c(ctx, cmd)
return cmd
}
// Get Redis `GET key` command. It returns redis.Nil error when key does not exist.
func (c cmdable) Get(ctx context.Context, key string) *StringCmd {
cmd := NewStringCmd(ctx, "get", key)
@@ -112,6 +194,35 @@ func (c cmdable) IncrByFloat(ctx context.Context, key string, value float64) *Fl
return cmd
}
type SetCondition string
const (
// NX only set the keys and their expiration if none exist
NX SetCondition = "NX"
// XX only set the keys and their expiration if all already exist
XX SetCondition = "XX"
)
type ExpirationMode string
const (
// EX sets expiration in seconds
EX ExpirationMode = "EX"
// PX sets expiration in milliseconds
PX ExpirationMode = "PX"
// EXAT sets expiration as Unix timestamp in seconds
EXAT ExpirationMode = "EXAT"
// PXAT sets expiration as Unix timestamp in milliseconds
PXAT ExpirationMode = "PXAT"
// KEEPTTL keeps the existing TTL
KEEPTTL ExpirationMode = "KEEPTTL"
)
type ExpirationOption struct {
Mode ExpirationMode
Value int64
}
func (c cmdable) LCS(ctx context.Context, q *LCSQuery) *LCSCmd {
cmd := NewLCSCmd(ctx, q)
_ = c(ctx, cmd)
@@ -157,6 +268,49 @@ func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd {
return cmd
}
type MSetEXArgs struct {
Condition SetCondition
Expiration *ExpirationOption
}
// MSetEX sets the given keys to their respective values.
// This command is an extension of the MSETNX that adds expiration and XX options.
// Available since Redis 8.4
// Important: When this method is used with Cluster clients, all keys
// must be in the same hash slot, otherwise CROSSSLOT error will be returned.
// For more information, see https://redis.io/commands/msetex
func (c cmdable) MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd {
expandedArgs := appendArgs([]interface{}{}, values)
numkeys := len(expandedArgs) / 2
cmdArgs := make([]interface{}, 0, 2+len(expandedArgs)+3)
cmdArgs = append(cmdArgs, "msetex", numkeys)
cmdArgs = append(cmdArgs, expandedArgs...)
if args.Condition != "" {
cmdArgs = append(cmdArgs, string(args.Condition))
}
if args.Expiration != nil {
switch args.Expiration.Mode {
case EX:
cmdArgs = append(cmdArgs, "ex", args.Expiration.Value)
case PX:
cmdArgs = append(cmdArgs, "px", args.Expiration.Value)
case EXAT:
cmdArgs = append(cmdArgs, "exat", args.Expiration.Value)
case PXAT:
cmdArgs = append(cmdArgs, "pxat", args.Expiration.Value)
case KEEPTTL:
cmdArgs = append(cmdArgs, "keepttl")
}
}
cmd := NewIntCmd(ctx, cmdArgs...)
_ = c(ctx, cmd)
return cmd
}
// Set Redis `SET key value [expiration]` command.
// Use expiration for `SETEx`-like behavior.
//
@@ -185,9 +339,24 @@ func (c cmdable) Set(ctx context.Context, key string, value interface{}, expirat
// SetArgs provides arguments for the SetArgs function.
type SetArgs struct {
// Mode can be `NX` or `XX` or empty.
// Mode can be `NX`, `XX`, `IFEQ`, `IFNE`, `IFDEQ`, `IFDNE` or empty.
Mode string
// MatchValue is used with IFEQ/IFNE modes for compare-and-set operations.
// - IFEQ: only set if current value equals MatchValue
// - IFNE: only set if current value does not equal MatchValue
MatchValue interface{}
// MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-set.
// - IFDEQ: only set if current value's digest equals MatchDigest
// - IFDNE: only set if current value's digest does not equal MatchDigest
//
// The digest is a uint64 xxh3 hash value.
//
// For examples of client-side digest generation, see:
// example/digest-optimistic-locking/
MatchDigest uint64
// Zero `TTL` or `Expiration` means that the key has no expiration time.
TTL time.Duration
ExpireAt time.Time
@@ -223,6 +392,18 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S
if a.Mode != "" {
args = append(args, a.Mode)
// Add match value/digest for CAS modes
switch a.Mode {
case "ifeq", "IFEQ", "ifne", "IFNE":
if a.MatchValue != nil {
args = append(args, a.MatchValue)
}
case "ifdeq", "IFDEQ", "ifdne", "IFDNE":
if a.MatchDigest != 0 {
args = append(args, fmt.Sprintf("%016x", a.MatchDigest))
}
}
}
if a.Get {
@@ -290,6 +471,270 @@ func (c cmdable) SetXX(ctx context.Context, key string, value interface{}, expir
return cmd
}
// SetIFEQ Redis `SET key value [expiration] IFEQ match-value` command.
// Compare-and-set: only sets the value if the current value equals matchValue.
//
// Returns "OK" on success.
// Returns nil if the operation was aborted due to condition not matching.
// Zero expiration means the key has no expiration time.
//
// NOTE SetIFEQ is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifeq", matchValue)
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFEQGet Redis `SET key value [expiration] IFEQ match-value GET` command.
// Compare-and-set with GET: only sets the value if the current value equals matchValue,
// and returns the previous value.
//
// Returns the previous value on success.
// Returns nil if the operation was aborted due to condition not matching.
// Zero expiration means the key has no expiration time.
//
// NOTE SetIFEQGet is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifeq", matchValue, "get")
cmd := NewStringCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFNE Redis `SET key value [expiration] IFNE match-value` command.
// Compare-and-set: only sets the value if the current value does not equal matchValue.
//
// Returns "OK" on success.
// Returns nil if the operation was aborted due to condition not matching.
// Zero expiration means the key has no expiration time.
//
// NOTE SetIFNE is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifne", matchValue)
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFNEGet Redis `SET key value [expiration] IFNE match-value GET` command.
// Compare-and-set with GET: only sets the value if the current value does not equal matchValue,
// and returns the previous value.
//
// Returns the previous value on success.
// Returns nil if the operation was aborted due to condition not matching.
// Zero expiration means the key has no expiration time.
//
// NOTE SetIFNEGet is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifne", matchValue, "get")
cmd := NewStringCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFDEQ sets the value only if the current value's digest equals matchDigest.
//
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
// The matchDigest parameter is a uint64 xxh3 hash value.
//
// Returns "OK" on success.
// Returns redis.Nil if the digest doesn't match (value was modified).
// Zero expiration means the key has no expiration time.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/set/
//
// NOTE SetIFNEQ is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest))
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFDEQGet sets the value only if the current value's digest equals matchDigest,
// and returns the previous value.
//
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
// The matchDigest parameter is a uint64 xxh3 hash value.
//
// Returns the previous value on success.
// Returns redis.Nil if the digest doesn't match (value was modified).
// Zero expiration means the key has no expiration time.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/set/
//
// NOTE SetIFNEQGet is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest), "get")
cmd := NewStringCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFDNE sets the value only if the current value's digest does NOT equal matchDigest.
//
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
// The matchDigest parameter is a uint64 xxh3 hash value.
//
// Returns "OK" on success (digest didn't match, value was set).
// Returns redis.Nil if the digest matches (value was not modified).
// Zero expiration means the key has no expiration time.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/set/
//
// NOTE SetIFDNE is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest))
cmd := NewStatusCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
// SetIFDNEGet sets the value only if the current value's digest does NOT equal matchDigest,
// and returns the previous value.
//
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
// The matchDigest parameter is a uint64 xxh3 hash value.
//
// Returns the previous value on success (digest didn't match, value was set).
// Returns redis.Nil if the digest matches (value was not modified).
// Zero expiration means the key has no expiration time.
//
// For examples of client-side digest generation and usage patterns, see:
// example/digest-optimistic-locking/
//
// Redis 8.4+. See https://redis.io/commands/set/
//
// NOTE SetIFDNEGet is still experimental
// it's signature and behaviour may change
func (c cmdable) SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd {
args := []interface{}{"set", key, value}
if expiration > 0 {
if usePrecise(expiration) {
args = append(args, "px", formatMs(ctx, expiration))
} else {
args = append(args, "ex", formatSec(ctx, expiration))
}
} else if expiration == KeepTTL {
args = append(args, "keepttl")
}
args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest), "get")
cmd := NewStringCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
func (c cmdable) SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd {
cmd := NewIntCmd(ctx, "setrange", key, offset, value)
_ = c(ctx, cmd)
+4 -3
View File
@@ -24,9 +24,10 @@ type Tx struct {
func (c *Client) newTx() *Tx {
tx := Tx{
baseClient: baseClient{
opt: c.opt,
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client
connPool: pool.NewStickyConnPool(c.connPool),
hooksMixin: c.hooksMixin.clone(),
pushProcessor: c.pushProcessor, // Copy push processor from parent client
},
}
tx.init()
+16 -9
View File
@@ -7,6 +7,7 @@ import (
"time"
"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/maintnotifications"
)
// UniversalOptions information is required by UniversalClient to establish
@@ -122,6 +123,9 @@ type UniversalOptions struct {
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
IsClusterMode bool
// MaintNotificationsConfig provides configuration for maintnotifications upgrades.
MaintNotificationsConfig *maintnotifications.Config
}
// Cluster returns cluster options created from the universal options.
@@ -172,11 +176,12 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
TLSConfig: o.TLSConfig,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
UnstableResp3: o.UnstableResp3,
MaintNotificationsConfig: o.MaintNotificationsConfig,
}
}
@@ -237,6 +242,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
// Note: MaintNotificationsConfig not supported for FailoverOptions
}
}
@@ -284,10 +290,11 @@ func (o *UniversalOptions) Simple() *Options {
TLSConfig: o.TLSConfig,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
DisableIdentity: o.DisableIdentity,
DisableIndentity: o.DisableIndentity,
IdentitySuffix: o.IdentitySuffix,
UnstableResp3: o.UnstableResp3,
MaintNotificationsConfig: o.MaintNotificationsConfig,
}
}
+11
View File
@@ -26,6 +26,7 @@ type VectorSetCmdable interface {
VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd
VSimWithArgs(ctx context.Context, key string, val Vector, args *VSimArgs) *StringSliceCmd
VSimWithArgsWithScores(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorScoreSliceCmd
VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd
}
type Vector interface {
@@ -345,3 +346,13 @@ func (c cmdable) VSimWithArgsWithScores(ctx context.Context, key string, val Vec
_ = c(ctx, cmd)
return cmd
}
// `VRANGE key start end count`
// a negative count means to return all the elements in the vector set.
// note: the API is experimental and may be subject to change.
func (c cmdable) VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd {
args := []any{"vrange", key, start, end, count}
cmd := NewStringSliceCmd(ctx, args...)
_ = c(ctx, cmd)
return cmd
}
+1 -1
View File
@@ -2,5 +2,5 @@ package redis
// Version is the current release version.
func Version() string {
return "9.14.0"
return "9.17.2"
}
+6 -1
View File
@@ -30,17 +30,22 @@ github.com/gorilla/sessions
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
# github.com/redis/go-redis/v9 v9.14.0
# github.com/redis/go-redis/v9 v9.17.2
## explicit; go 1.18
github.com/redis/go-redis/v9
github.com/redis/go-redis/v9/auth
github.com/redis/go-redis/v9/internal
github.com/redis/go-redis/v9/internal/auth/streaming
github.com/redis/go-redis/v9/internal/hashtag
github.com/redis/go-redis/v9/internal/hscan
github.com/redis/go-redis/v9/internal/interfaces
github.com/redis/go-redis/v9/internal/maintnotifications/logs
github.com/redis/go-redis/v9/internal/pool
github.com/redis/go-redis/v9/internal/proto
github.com/redis/go-redis/v9/internal/rand
github.com/redis/go-redis/v9/internal/util
github.com/redis/go-redis/v9/maintnotifications
github.com/redis/go-redis/v9/push
# github.com/stretchr/testify v1.10.0
## explicit; go 1.17
github.com/stretchr/testify/assert