mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Size computation for allocation may overflow (#99)
* Size computation for allocation may overflow: performing calculations involving the size of potentially large strings or slices can result in an overflow (for signed integer types) or a wraparound (for unsigned types). An overflow causes the result of the calculation to become negative, while a wraparound results in a small (positive) number.
This commit is contained in:
@@ -18,6 +18,6 @@ jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24"
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
|
||||
@@ -17,5 +17,5 @@ jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24"
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
|
||||
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
|
||||
@@ -252,6 +252,7 @@ func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
|
||||
@@ -348,6 +348,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
||||
@@ -538,6 +538,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
|
||||
@@ -20,8 +20,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
|
||||
Vendored
+2
-2
@@ -76,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+23
-22
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
|
||||
@@ -158,6 +158,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
@@ -181,6 +182,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
@@ -203,6 +205,7 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
// #nosec G115 -- MaxFailures is a small config value that fits in int32
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
|
||||
+2
@@ -217,6 +217,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
@@ -241,6 +242,7 @@ func (hc *HealthChecker) recordFailure() {
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
|
||||
@@ -150,6 +150,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
// #nosec G115 -- threshold config values are small integers that fit in int32
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
@@ -304,7 +304,8 @@ func (f *Factory) createSecureTLSConfig() *tls.Config {
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it to false is safe
|
||||
PreferServerCipherSuites: false, // Let client choose best cipher
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,12 +144,14 @@ func getOrCreateLogFile(filename string) io.Writer {
|
||||
}
|
||||
|
||||
// Ensure log directory exists
|
||||
// #nosec G301 -- log directory needs to be readable by monitoring tools
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
// Fall back to stderr if we can't create the directory
|
||||
return os.Stderr
|
||||
}
|
||||
|
||||
filepath := logDir + "/" + filename
|
||||
// #nosec G302 G304 -- log files need to be readable; path is from trusted env var
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
// Fall back to stderr if we can't open the file
|
||||
|
||||
@@ -107,6 +107,7 @@ const (
|
||||
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
|
||||
|
||||
// Bearer token pattern (Authorization header)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
|
||||
|
||||
// Client ID pattern (alphanumeric with common separators)
|
||||
@@ -119,6 +120,7 @@ const (
|
||||
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
|
||||
|
||||
// CSRF token pattern (base64url)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
|
||||
|
||||
// Nonce pattern (base64url)
|
||||
|
||||
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
|
||||
@@ -148,6 +148,7 @@ func (cb *CircuitBreaker) allowRequest() bool {
|
||||
// allowHalfOpenRequest checks if a request is allowed in half-open state
|
||||
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
|
||||
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
|
||||
// #nosec G115 -- MaxRequests is a small config value that fits in int32
|
||||
if current <= int32(cb.config.MaxRequests) {
|
||||
return true
|
||||
}
|
||||
@@ -164,6 +165,7 @@ func (cb *CircuitBreaker) recordFailure() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
|
||||
cb.transitionToOpen()
|
||||
} else if state == CircuitBreakerHalfOpen {
|
||||
@@ -180,6 +182,7 @@ func (cb *CircuitBreaker) recordSuccess() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
|
||||
cb.transitionToClosed()
|
||||
}
|
||||
|
||||
@@ -191,6 +191,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
// Add jitter
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.RandomizationFactor > 0 {
|
||||
jitter := delay * re.config.RandomizationFactor
|
||||
minDelay := delay - jitter
|
||||
|
||||
@@ -169,7 +169,10 @@ func (jwk *JWK) ToRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
// Pad to 8 bytes for uint64
|
||||
paddedE := make([]byte, 8)
|
||||
copy(paddedE[8-len(eBytes):], eBytes)
|
||||
e = int(binary.BigEndian.Uint64(paddedE))
|
||||
eUint64 := binary.BigEndian.Uint64(paddedE)
|
||||
// RSA exponents are typically small (65537 is common), so overflow is not a concern
|
||||
// #nosec G115 -- RSA public exponents are small values that fit in int
|
||||
e = int(eUint64)
|
||||
} else {
|
||||
return nil, fmt.Errorf("exponent too large")
|
||||
}
|
||||
|
||||
+8
-5
@@ -120,8 +120,9 @@ func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryM
|
||||
alertThresholds: thresholds,
|
||||
baselineHeap: memStats.HeapAlloc,
|
||||
baselineGoroutines: runtime.NumGoroutine(),
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,9 +159,10 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
StackSysBytes: memStats.StackSys,
|
||||
GCSysBytes: memStats.GCSys,
|
||||
NumGoroutines: runtime.NumGoroutine(),
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: gcFrequency,
|
||||
Timestamp: now,
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: gcFrequency,
|
||||
Timestamp: now,
|
||||
}
|
||||
|
||||
// Get application-specific stats
|
||||
@@ -386,6 +388,7 @@ func (mm *MemoryMonitor) TriggerGC() {
|
||||
|
||||
after := mm.GetCurrentStats()
|
||||
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
|
||||
freedMB := float64(freedBytes) / (1024 * 1024)
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ func generateSecureRandomString(length int) (string, error) {
|
||||
}
|
||||
|
||||
// Cookie names and configuration constants used for session management
|
||||
// #nosec G101 -- These are cookie names, not hardcoded credentials
|
||||
const (
|
||||
mainCookieName = "_oidc_raczylo_m"
|
||||
accessTokenCookie = "_oidc_raczylo_a"
|
||||
|
||||
+3
-2
@@ -51,7 +51,8 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
|
||||
}
|
||||
|
||||
return &ShardedCache{
|
||||
shards: shards,
|
||||
shards: shards,
|
||||
// #nosec G115 -- numShards is validated to be positive and small (typically 32-256)
|
||||
numShards: uint32(numShards),
|
||||
maxPerShard: maxPerShard,
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
|
||||
// FNV-1a is fast and provides good distribution.
|
||||
func (c *ShardedCache) getShard(key string) *cacheShard {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(key))
|
||||
_, _ = h.Write([]byte(key)) // hash.Hash.Write never returns an error
|
||||
return c.shards[h.Sum32()%c.numShards]
|
||||
}
|
||||
|
||||
|
||||
@@ -734,6 +734,7 @@ func (r *TestSuiteRunner) RunMemoryLeakTests(t *testing.T, tests []MemoryLeakTes
|
||||
}
|
||||
|
||||
// Check memory growth
|
||||
// #nosec G115 -- memory stats are within int64 range for practical purposes
|
||||
memoryGrowthBytes := int64(finalMem.Alloc) - int64(initialMem.Alloc)
|
||||
memoryGrowthMB := float64(memoryGrowthBytes) / (1024 * 1024)
|
||||
|
||||
|
||||
+4
@@ -9,3 +9,7 @@ coverage.txt
|
||||
**/coverage.txt
|
||||
.vscode
|
||||
tmp/*
|
||||
*.test
|
||||
|
||||
# maintenanceNotifications upgrade documentation (temporary)
|
||||
maintenanceNotifications/docs/
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
|
||||
REDIS_VERSION ?= 8.2
|
||||
REDIS_VERSION ?= 8.4
|
||||
RE_CLUSTER ?= false
|
||||
RCE_DOCKER ?= true
|
||||
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.2.1-pre
|
||||
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.4.0
|
||||
|
||||
docker.start:
|
||||
export RE_CLUSTER=$(RE_CLUSTER) && \
|
||||
|
||||
+149
-14
@@ -2,7 +2,7 @@
|
||||
|
||||
[](https://github.com/redis/go-redis/actions)
|
||||
[](https://pkg.go.dev/github.com/redis/go-redis/v9?tab=doc)
|
||||
[](https://redis.uptrace.dev/)
|
||||
[](https://redis.io/docs/latest/develop/clients/go/)
|
||||
[](https://goreportcard.com/report/github.com/redis/go-redis/v9)
|
||||
[](https://codecov.io/github/redis/go-redis)
|
||||
|
||||
@@ -17,15 +17,15 @@
|
||||
## Supported versions
|
||||
|
||||
In `go-redis` we are aiming to support the last three releases of Redis. Currently, this means we do support:
|
||||
- [Redis 7.2](https://raw.githubusercontent.com/redis/redis/7.2/00-RELEASENOTES) - using Redis Stack 7.2 for modules support
|
||||
- [Redis 7.4](https://raw.githubusercontent.com/redis/redis/7.4/00-RELEASENOTES) - using Redis Stack 7.4 for modules support
|
||||
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0 where modules are included
|
||||
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2 where modules are included
|
||||
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0
|
||||
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2
|
||||
- [Redis 8.4](https://raw.githubusercontent.com/redis/redis/8.4/00-RELEASENOTES) - using Redis CE 8.4
|
||||
|
||||
Although the `go.mod` states it requires at minimum `go 1.18`, our CI is configured to run the tests against all three
|
||||
versions of Redis and the latest two versions of Go ([1.23](https://go.dev/doc/devel/release#go1.23.0),
|
||||
[1.24](https://go.dev/doc/devel/release#go1.24.0)). We have observed that some module-related tests may not pass with
|
||||
Redis Stack 7.2, and that some commands have changed in Redis CE 8.0.
|
||||
Although it is not officially supported, `go-redis/v9` should be able to work with any Redis 7.0+.
|
||||
Please do refer to the documentation and the tests if you experience any issues. We do plan to update the go version
|
||||
in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
|
||||
@@ -43,10 +43,6 @@ in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
|
||||
[Work at Redis](https://redis.com/company/careers/jobs/)
|
||||
|
||||
## Documentation
|
||||
|
||||
- [English](https://redis.uptrace.dev)
|
||||
- [简体中文](https://redis.uptrace.dev/zh/)
|
||||
|
||||
## Resources
|
||||
|
||||
@@ -55,16 +51,18 @@ in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
- [Reference](https://pkg.go.dev/github.com/redis/go-redis/v9)
|
||||
- [Examples](https://pkg.go.dev/github.com/redis/go-redis/v9#pkg-examples)
|
||||
|
||||
## Old documentation
|
||||
|
||||
- [English](https://redis.uptrace.dev)
|
||||
- [简体中文](https://redis.uptrace.dev/zh/)
|
||||
|
||||
## Ecosystem
|
||||
|
||||
- [Redis Mock](https://github.com/go-redis/redismock)
|
||||
- [Entra ID (Azure AD)](https://github.com/redis/go-redis-entraid)
|
||||
- [Distributed Locks](https://github.com/bsm/redislock)
|
||||
- [Redis Cache](https://github.com/go-redis/cache)
|
||||
- [Rate limiting](https://github.com/go-redis/redis_rate)
|
||||
|
||||
This client also works with [Kvrocks](https://github.com/apache/incubator-kvrocks), a distributed
|
||||
key value NoSQL database that uses RocksDB as storage engine and is compatible with Redis protocol.
|
||||
|
||||
## Features
|
||||
|
||||
- Redis commands except QUIT and SYNC.
|
||||
@@ -75,7 +73,6 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
|
||||
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
|
||||
- [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html).
|
||||
- [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html).
|
||||
- [Redis Ring](https://redis.uptrace.dev/guide/ring.html).
|
||||
- [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html).
|
||||
- [Redis Probabilistic [RedisStack]](https://redis.io/docs/data-types/probabilistic/)
|
||||
- [Customizable read and write buffers size.](#custom-buffer-sizes)
|
||||
@@ -429,6 +426,144 @@ vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello")
|
||||
res, err := rdb.Do(ctx, "set", "key", "value").Result()
|
||||
```
|
||||
|
||||
## Typed Errors
|
||||
|
||||
go-redis provides typed error checking functions for common Redis errors:
|
||||
|
||||
```go
|
||||
// Cluster and replication errors
|
||||
redis.IsLoadingError(err) // Redis is loading the dataset
|
||||
redis.IsReadOnlyError(err) // Write to read-only replica
|
||||
redis.IsClusterDownError(err) // Cluster is down
|
||||
redis.IsTryAgainError(err) // Command should be retried
|
||||
redis.IsMasterDownError(err) // Master is down
|
||||
redis.IsMovedError(err) // Returns (address, true) if key moved
|
||||
redis.IsAskError(err) // Returns (address, true) if key being migrated
|
||||
|
||||
// Connection and resource errors
|
||||
redis.IsMaxClientsError(err) // Maximum clients reached
|
||||
redis.IsAuthError(err) // Authentication failed (NOAUTH, WRONGPASS, unauthenticated)
|
||||
redis.IsPermissionError(err) // Permission denied (NOPERM)
|
||||
redis.IsOOMError(err) // Out of memory (OOM)
|
||||
|
||||
// Transaction errors
|
||||
redis.IsExecAbortError(err) // Transaction aborted (EXECABORT)
|
||||
```
|
||||
|
||||
### Error Wrapping in Hooks
|
||||
|
||||
When wrapping errors in hooks, use a custom error type with an `Unwrap()` method (preferred) or `fmt.Errorf` with `%w`. Always call `cmd.SetErr()` to preserve error type information:
|
||||
|
||||
```go
|
||||
// Custom error type (preferred)
|
||||
type AppError struct {
|
||||
Code string
|
||||
RequestID string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AppError) Error() string {
|
||||
return fmt.Sprintf("[%s] request_id=%s: %v", e.Code, e.RequestID, e.Err)
|
||||
}
|
||||
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// Hook implementation
|
||||
func (h MyHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||
err := next(ctx, cmd)
|
||||
if err != nil {
|
||||
// Wrap with custom error type
|
||||
wrappedErr := &AppError{
|
||||
Code: "REDIS_ERROR",
|
||||
RequestID: getRequestID(ctx),
|
||||
Err: err,
|
||||
}
|
||||
cmd.SetErr(wrappedErr)
|
||||
return wrappedErr // Return wrapped error to preserve it
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Typed error detection works through wrappers
|
||||
if redis.IsLoadingError(err) {
|
||||
// Retry logic
|
||||
}
|
||||
|
||||
// Extract custom error if needed
|
||||
var appErr *AppError
|
||||
if errors.As(err, &appErr) {
|
||||
log.Printf("Request: %s", appErr.RequestID)
|
||||
}
|
||||
```
|
||||
|
||||
Alternatively, use `fmt.Errorf` with `%w`:
|
||||
```go
|
||||
wrappedErr := fmt.Errorf("context: %w", err)
|
||||
cmd.SetErr(wrappedErr)
|
||||
```
|
||||
|
||||
### Pipeline Hook Example
|
||||
|
||||
For pipeline operations, use `ProcessPipelineHook`:
|
||||
|
||||
```go
|
||||
type PipelineLoggingHook struct{}
|
||||
|
||||
func (h PipelineLoggingHook) DialHook(next redis.DialHook) redis.DialHook {
|
||||
return next
|
||||
}
|
||||
|
||||
func (h PipelineLoggingHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||
return next
|
||||
}
|
||||
|
||||
func (h PipelineLoggingHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
start := time.Now()
|
||||
|
||||
// Execute the pipeline
|
||||
err := next(ctx, cmds)
|
||||
|
||||
duration := time.Since(start)
|
||||
log.Printf("Pipeline executed %d commands in %v", len(cmds), duration)
|
||||
|
||||
// Process individual command errors
|
||||
// Note: Individual command errors are already set on each cmd by the pipeline execution
|
||||
for _, cmd := range cmds {
|
||||
if cmdErr := cmd.Err(); cmdErr != nil {
|
||||
// Check for specific error types using typed error functions
|
||||
if redis.IsAuthError(cmdErr) {
|
||||
log.Printf("Auth error in pipeline command %s: %v", cmd.Name(), cmdErr)
|
||||
} else if redis.IsPermissionError(cmdErr) {
|
||||
log.Printf("Permission error in pipeline command %s: %v", cmd.Name(), cmdErr)
|
||||
}
|
||||
|
||||
// Optionally wrap individual command errors to add context
|
||||
// The wrapped error preserves type information through errors.As()
|
||||
wrappedErr := fmt.Errorf("pipeline cmd %s failed: %w", cmd.Name(), cmdErr)
|
||||
cmd.SetErr(wrappedErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the pipeline-level error (connection errors, etc.)
|
||||
// You can wrap it if needed, or return it as-is
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Register the hook
|
||||
rdb.AddHook(PipelineLoggingHook{})
|
||||
|
||||
// Use pipeline - errors are still properly typed
|
||||
pipe := rdb.Pipeline()
|
||||
pipe.Set(ctx, "key1", "value1", 0)
|
||||
pipe.Get(ctx, "key2")
|
||||
_, err := pipe.Exec(ctx)
|
||||
```
|
||||
|
||||
## Run the test
|
||||
|
||||
|
||||
+224
@@ -1,5 +1,229 @@
|
||||
# Release Notes
|
||||
|
||||
# 9.17.2 (2025-12-01)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **Connection Pool**: Fixed critical race condition in turn management that could cause connection leaks when dial goroutines complete after request timeout ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
|
||||
- **Context Timeout**: Improved context timeout calculation to use minimum of remaining time and DialTimeout, preventing goroutines from waiting longer than necessary ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.54.0 to 0.55.0 ([#3627](https://github.com/redis/go-redis/pull/3627))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@cyningsun](https://github.com/cyningsun) and [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.1...v9.17.2
|
||||
|
||||
# 9.17.1 (2025-11-25)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- add wait to keyless commands list ([#3615](https://github.com/redis/go-redis/pull/3615)) by [@marcoferrer](https://github.com/marcoferrer)
|
||||
- fix(time): remove cached time optimization ([#3611](https://github.com/redis/go-redis/pull/3611)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump golangci/golangci-lint-action from 9.0.0 to 9.1.0 ([#3609](https://github.com/redis/go-redis/pull/3609))
|
||||
- chore(deps): bump actions/checkout from 5 to 6 ([#3610](https://github.com/redis/go-redis/pull/3610))
|
||||
- chore(script): fix help call in tag.sh ([#3606](https://github.com/redis/go-redis/pull/3606)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@marcoferrer](https://github.com/marcoferrer) and [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.0...v9.17.1
|
||||
|
||||
# 9.17.0 (2025-11-19)
|
||||
|
||||
## 🚀 Highlights
|
||||
|
||||
### Redis 8.4 Support
|
||||
Added support for Redis 8.4, including new commands and features ([#3572](https://github.com/redis/go-redis/pull/3572))
|
||||
|
||||
### Typed Errors
|
||||
Introduced typed errors for better error handling using `errors.As` instead of string checks. Errors can now be wrapped and set to commands in hooks without breaking library functionality ([#3602](https://github.com/redis/go-redis/pull/3602))
|
||||
|
||||
### New Commands
|
||||
- **CAS/CAD Commands**: Added support for Compare-And-Set/Compare-And-Delete operations with conditional matching (`IFEQ`, `IFNE`, `IFDEQ`, `IFDNE`) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595))
|
||||
- **MSETEX**: Atomically set multiple key-value pairs with expiration options and conditional modes ([#3580](https://github.com/redis/go-redis/pull/3580))
|
||||
- **XReadGroup CLAIM**: Consume both incoming and idle pending entries from streams in a single call ([#3578](https://github.com/redis/go-redis/pull/3578))
|
||||
- **ACL Commands**: Added `ACLGenPass`, `ACLUsers`, and `ACLWhoAmI` ([#3576](https://github.com/redis/go-redis/pull/3576))
|
||||
- **SLOWLOG Commands**: Added `SLOWLOG LEN` and `SLOWLOG RESET` ([#3585](https://github.com/redis/go-redis/pull/3585))
|
||||
- **LATENCY Commands**: Added `LATENCY LATEST` and `LATENCY RESET` ([#3584](https://github.com/redis/go-redis/pull/3584))
|
||||
|
||||
### Search & Vector Improvements
|
||||
- **Hybrid Search**: Added **EXPERIMENTAL** support for the new `FT.HYBRID` command ([#3573](https://github.com/redis/go-redis/pull/3573))
|
||||
- **Vector Range**: Added `VRANGE` command for vector sets ([#3543](https://github.com/redis/go-redis/pull/3543))
|
||||
- **FT.INFO Enhancements**: Added vector-specific attributes in FT.INFO response ([#3596](https://github.com/redis/go-redis/pull/3596))
|
||||
|
||||
### Connection Pool Improvements
|
||||
- **Improved Connection Success Rate**: Implemented FIFO queue-based fairness and context pattern for connection creation to prevent premature cancellation under high concurrency ([#3518](https://github.com/redis/go-redis/pull/3518))
|
||||
- **Connection State Machine**: Resolved race conditions and improved pool performance with proper state tracking ([#3559](https://github.com/redis/go-redis/pull/3559))
|
||||
- **Pool Performance**: Significant performance improvements with faster semaphores, lockless hook manager, and reduced allocations (47-67% faster Get/Put operations) ([#3565](https://github.com/redis/go-redis/pull/3565))
|
||||
|
||||
### Metrics & Observability
|
||||
- **Canceled Metric Attribute**: Added 'canceled' metrics attribute to distinguish context cancellation errors from other errors ([#3566](https://github.com/redis/go-redis/pull/3566))
|
||||
|
||||
## ✨ New Features
|
||||
|
||||
- Typed errors with wrapping support ([#3602](https://github.com/redis/go-redis/pull/3602)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- CAS/CAD commands (marked as experimental) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595)) by [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis)
|
||||
- MSETEX command support ([#3580](https://github.com/redis/go-redis/pull/3580)) by [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
- XReadGroup CLAIM argument ([#3578](https://github.com/redis/go-redis/pull/3578)) by [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
- ACL commands: GenPass, Users, WhoAmI ([#3576](https://github.com/redis/go-redis/pull/3576)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- SLOWLOG commands: LEN, RESET ([#3585](https://github.com/redis/go-redis/pull/3585)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- LATENCY commands: LATEST, RESET ([#3584](https://github.com/redis/go-redis/pull/3584)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- Hybrid search command (FT.HYBRID) ([#3573](https://github.com/redis/go-redis/pull/3573)) by [@htemelski-redis](https://github.com/htemelski-redis)
|
||||
- Vector range command (VRANGE) ([#3543](https://github.com/redis/go-redis/pull/3543)) by [@cxljs](https://github.com/cxljs)
|
||||
- Vector-specific attributes in FT.INFO ([#3596](https://github.com/redis/go-redis/pull/3596)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Improved connection pool success rate with FIFO queue ([#3518](https://github.com/redis/go-redis/pull/3518)) by [@cyningsun](https://github.com/cyningsun)
|
||||
- Canceled metrics attribute for context errors ([#3566](https://github.com/redis/go-redis/pull/3566)) by [@pvragov](https://github.com/pvragov)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- Fixed Failover Client MaintNotificationsConfig ([#3600](https://github.com/redis/go-redis/pull/3600)) by [@ajax16384](https://github.com/ajax16384)
|
||||
- Fixed ACLGenPass function to use the bit parameter ([#3597](https://github.com/redis/go-redis/pull/3597)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- Return error instead of panic from commands ([#3568](https://github.com/redis/go-redis/pull/3568)) by [@dragneelfps](https://github.com/dragneelfps)
|
||||
- Safety harness in `joinErrors` to prevent panic ([#3577](https://github.com/redis/go-redis/pull/3577)) by [@manisharma](https://github.com/manisharma)
|
||||
|
||||
## ⚡ Performance
|
||||
|
||||
- Connection state machine with race condition fixes ([#3559](https://github.com/redis/go-redis/pull/3559)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Pool performance improvements: 47-67% faster Get/Put, 33% less memory, 50% fewer allocations ([#3565](https://github.com/redis/go-redis/pull/3565)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## 🧪 Testing & Infrastructure
|
||||
|
||||
- Updated to Redis 8.4.0 image ([#3603](https://github.com/redis/go-redis/pull/3603)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Added Redis 8.4-RC1-pre to CI ([#3572](https://github.com/redis/go-redis/pull/3572)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Refactored tests for idiomatic Go ([#3561](https://github.com/redis/go-redis/pull/3561), [#3562](https://github.com/redis/go-redis/pull/3562), [#3563](https://github.com/redis/go-redis/pull/3563)) by [@12ya](https://github.com/12ya)
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@12ya](https://github.com/12ya), [@ajax16384](https://github.com/ajax16384), [@cxljs](https://github.com/cxljs), [@cyningsun](https://github.com/cyningsun), [@destinyoooo](https://github.com/destinyoooo), [@dragneelfps](https://github.com/dragneelfps), [@htemelski-redis](https://github.com/htemelski-redis), [@manisharma](https://github.com/manisharma), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@pvragov](https://github.com/pvragov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.16.0...v9.17.0
|
||||
|
||||
# 9.16.0 (2025-10-23)
|
||||
|
||||
## 🚀 Highlights
|
||||
|
||||
### Maintenance Notifications Support
|
||||
|
||||
This release introduces comprehensive support for Redis maintenance notifications, enabling applications to handle server maintenance events gracefully. The new `maintnotifications` package provides:
|
||||
|
||||
- **RESP3 Push Notifications**: Full support for Redis RESP3 protocol push notifications
|
||||
- **Connection Handoff**: Automatic connection migration during server maintenance with configurable retry policies and circuit breakers
|
||||
- **Graceful Degradation**: Configurable timeout relaxation during maintenance windows to prevent false failures
|
||||
- **Event-Driven Architecture**: Background workers with on-demand scaling for efficient handoff processing
|
||||
- **Production-Ready**: Comprehensive E2E testing framework and monitoring capabilities
|
||||
|
||||
For detailed usage examples and configuration options, see the [maintenance notifications documentation](maintnotifications/README.md).
|
||||
|
||||
## ✨ New Features
|
||||
|
||||
- **Trace Filtering**: Add support for filtering traces for specific commands, including pipeline operations and dial operations ([#3519](https://github.com/redis/go-redis/pull/3519), [#3550](https://github.com/redis/go-redis/pull/3550))
|
||||
- New `TraceCmdFilter` option to selectively trace commands
|
||||
- Reduces overhead by excluding high-frequency or low-value commands from traces
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **Pipeline Error Handling**: Fix issue where pipeline repeatedly sets the same error ([#3525](https://github.com/redis/go-redis/pull/3525))
|
||||
- **Connection Pool**: Ensure re-authentication does not interfere with connection handoff operations ([#3547](https://github.com/redis/go-redis/pull/3547))
|
||||
|
||||
## 🔧 Improvements
|
||||
|
||||
- **Hash Commands**: Update hash command implementations ([#3523](https://github.com/redis/go-redis/pull/3523))
|
||||
- **OpenTelemetry**: Use `metric.WithAttributeSet` to avoid unnecessary attribute copying in redisotel ([#3552](https://github.com/redis/go-redis/pull/3552))
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- **Cluster Client**: Add explanation for why `MaxRetries` is disabled for `ClusterClient` ([#3551](https://github.com/redis/go-redis/pull/3551))
|
||||
|
||||
## 🧪 Testing & Infrastructure
|
||||
|
||||
- **E2E Testing**: Upgrade E2E testing framework with improved reliability and coverage ([#3541](https://github.com/redis/go-redis/pull/3541))
|
||||
- **Release Process**: Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
|
||||
|
||||
## 📦 Dependencies
|
||||
|
||||
- Bump `rojopolis/spellcheck-github-actions` from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
|
||||
- Bump `github/codeql-action` from 3 to 4 ([#3544](https://github.com/redis/go-redis/pull/3544))
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@Sovietaced](https://github.com/Sovietaced), [@Udhayarajan](https://github.com/Udhayarajan), [@boekkooi-impossiblecloud](https://github.com/boekkooi-impossiblecloud), [@Pika-Gopher](https://github.com/Pika-Gopher), [@cxljs](https://github.com/cxljs), [@huiyifyj](https://github.com/huiyifyj), [@omid-h70](https://github.com/omid-h70)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.14.0...v9.16.0
|
||||
|
||||
|
||||
# 9.15.0 was accidentally released. Please use version 9.16.0 instead.
|
||||
|
||||
# 9.15.0-beta.3 (2025-09-26)
|
||||
|
||||
## Highlights
|
||||
This beta release includes pre-production support for processing push notifications and for hitless upgrades.
|
||||
|
||||
# Changes
|
||||
|
||||
- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523))
|
||||
|
||||
## 🚀 New Features
|
||||
|
||||
- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525))
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
|
||||
- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526))
|
||||
- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70)
|
||||
|
||||
|
||||
# 9.15.0-beta.1 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
This beta release includes pre-production support for processing push notifications and for hitless upgrades.
|
||||
|
||||
### Hitless Upgrades
|
||||
Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters.
|
||||
You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless).
|
||||
|
||||
# Changes
|
||||
|
||||
## 🚀 New Features
|
||||
- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
|
||||
|
||||
# 9.14.0 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
|
||||
+27
@@ -8,8 +8,12 @@ type ACLCmdable interface {
|
||||
ACLLog(ctx context.Context, count int64) *ACLLogCmd
|
||||
ACLLogReset(ctx context.Context) *StatusCmd
|
||||
|
||||
ACLGenPass(ctx context.Context, bit int) *StringCmd
|
||||
|
||||
ACLSetUser(ctx context.Context, username string, rules ...string) *StatusCmd
|
||||
ACLDelUser(ctx context.Context, username string) *IntCmd
|
||||
ACLUsers(ctx context.Context) *StringSliceCmd
|
||||
ACLWhoAmI(ctx context.Context) *StringCmd
|
||||
ACLList(ctx context.Context) *StringSliceCmd
|
||||
|
||||
ACLCat(ctx context.Context) *StringSliceCmd
|
||||
@@ -65,6 +69,29 @@ func (c cmdable) ACLSetUser(ctx context.Context, username string, rules ...strin
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLGenPass(ctx context.Context, bit int) *StringCmd {
|
||||
args := make([]interface{}, 0, 3)
|
||||
args = append(args, "acl", "genpass")
|
||||
if bit > 0 {
|
||||
args = append(args, bit)
|
||||
}
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLUsers(ctx context.Context) *StringSliceCmd {
|
||||
cmd := NewStringSliceCmd(ctx, "acl", "users")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLWhoAmI(ctx context.Context) *StringCmd {
|
||||
cmd := NewStringCmd(ctx, "acl", "whoami")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLList(ctx context.Context) *StringSliceCmd {
|
||||
cmd := NewStringSliceCmd(ctx, "acl", "list")
|
||||
_ = c(ctx, cmd)
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
var (
	// ErrInvalidCommand reports that ExecuteCommand was given a command of an
	// unsupported type.
	ErrInvalidCommand = errors.New("invalid command type")

	// ErrInvalidPool reports that the pool type is not supported.
	ErrInvalidPool = errors.New("invalid pool type")
)
|
||||
|
||||
// newClientAdapter creates a new client adapter for regular Redis clients.
|
||||
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
|
||||
return &clientAdapter{client: client}
|
||||
}
|
||||
|
||||
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
|
||||
type clientAdapter struct {
|
||||
client *baseClient
|
||||
}
|
||||
|
||||
// GetOptions returns the client options.
|
||||
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
|
||||
return &optionsAdapter{options: ca.client.opt}
|
||||
}
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
|
||||
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
|
||||
}
|
||||
|
||||
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
|
||||
type optionsAdapter struct {
|
||||
options *Options
|
||||
}
|
||||
|
||||
// GetReadTimeout returns the read timeout.
|
||||
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
|
||||
return oa.options.ReadTimeout
|
||||
}
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
|
||||
return oa.options.WriteTimeout
|
||||
}
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
func (oa *optionsAdapter) GetNetwork() string {
|
||||
return oa.options.Network
|
||||
}
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
func (oa *optionsAdapter) GetAddr() string {
|
||||
return oa.options.Addr
|
||||
}
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
func (oa *optionsAdapter) IsTLSEnabled() bool {
|
||||
return oa.options.TLSConfig != nil
|
||||
}
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
func (oa *optionsAdapter) GetProtocol() int {
|
||||
return oa.options.Protocol
|
||||
}
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
func (oa *optionsAdapter) GetPoolSize() int {
|
||||
return oa.options.PoolSize
|
||||
}
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
|
||||
baseDialer := oa.options.NewDialer()
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Extract network and address from the options
|
||||
network := oa.options.Network
|
||||
addr := oa.options.Addr
|
||||
return baseDialer(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
|
||||
type pushProcessorAdapter struct {
|
||||
processor push.NotificationProcessor
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
|
||||
if pushHandler, ok := handler.(push.NotificationHandler); ok {
|
||||
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
|
||||
}
|
||||
return errors.New("handler must implement push.NotificationHandler")
|
||||
}
|
||||
|
||||
// UnregisterHandler removes a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
|
||||
return ppa.processor.UnregisterHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
|
||||
return ppa.processor.GetHandler(pushNotificationName)
|
||||
}
|
||||
+6
-2
@@ -141,7 +141,9 @@ func (c cmdable) BitPos(ctx context.Context, key string, bit int64, pos ...int64
|
||||
args[3] = pos[0]
|
||||
args[4] = pos[1]
|
||||
default:
|
||||
panic("too many arguments")
|
||||
cmd := NewIntCmd(ctx)
|
||||
cmd.SetErr(errors.New("too many arguments"))
|
||||
return cmd
|
||||
}
|
||||
cmd := NewIntCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
@@ -182,7 +184,9 @@ func (c cmdable) BitFieldRO(ctx context.Context, key string, values ...interface
|
||||
args[0] = "BITFIELD_RO"
|
||||
args[1] = key
|
||||
if len(values)%2 != 0 {
|
||||
panic("BitFieldRO: invalid number of arguments, must be even")
|
||||
c := NewIntSliceCmd(ctx)
|
||||
c.SetErr(errors.New("BitFieldRO: invalid number of arguments, must be even"))
|
||||
return c
|
||||
}
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
args = append(args, "GET", values[i], values[i+1])
|
||||
|
||||
+169
-3
@@ -64,6 +64,7 @@ var keylessCommands = map[string]struct{}{
|
||||
"sync": {},
|
||||
"unsubscribe": {},
|
||||
"unwatch": {},
|
||||
"wait": {},
|
||||
}
|
||||
|
||||
type Cmder interface {
|
||||
@@ -698,6 +699,68 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// DigestCmd is a command that returns a uint64 xxh3 hash digest.
|
||||
//
|
||||
// This command is specifically designed for the Redis DIGEST command,
|
||||
// which returns the xxh3 hash of a key's value as a hex string.
|
||||
// The hex string is automatically parsed to a uint64 value.
|
||||
//
|
||||
// The digest can be used for optimistic locking with SetIFDEQ, SetIFDNE,
|
||||
// and DelExArgs commands.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/digest/
|
||||
type DigestCmd struct {
|
||||
baseCmd
|
||||
|
||||
val uint64
|
||||
}
|
||||
|
||||
var _ Cmder = (*DigestCmd)(nil)
|
||||
|
||||
func NewDigestCmd(ctx context.Context, args ...interface{}) *DigestCmd {
|
||||
return &DigestCmd{
|
||||
baseCmd: baseCmd{
|
||||
ctx: ctx,
|
||||
args: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) SetVal(val uint64) {
|
||||
cmd.val = val
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) Val() uint64 {
|
||||
return cmd.val
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) Result() (uint64, error) {
|
||||
return cmd.val, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) String() string {
|
||||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) readReply(rd *proto.Reader) (err error) {
|
||||
// Redis DIGEST command returns a hex string (e.g., "a1b2c3d4e5f67890")
|
||||
// We parse it as a uint64 xxh3 hash value
|
||||
var hexStr string
|
||||
hexStr, err = rd.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse hex string to uint64
|
||||
cmd.val, err = strconv.ParseUint(hexStr, 16, 64)
|
||||
return err
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type IntSliceCmd struct {
|
||||
baseCmd
|
||||
|
||||
@@ -1585,6 +1648,12 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error {
|
||||
type XMessage struct {
|
||||
ID string
|
||||
Values map[string]interface{}
|
||||
// MillisElapsedFromDelivery is the number of milliseconds since the entry was last delivered.
|
||||
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
|
||||
MillisElapsedFromDelivery int64
|
||||
// DeliveredCount is the number of times the entry was delivered.
|
||||
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
|
||||
DeliveredCount int64
|
||||
}
|
||||
|
||||
type XMessageSliceCmd struct {
|
||||
@@ -1641,10 +1710,16 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) {
|
||||
}
|
||||
|
||||
func readXMessage(rd *proto.Reader) (XMessage, error) {
|
||||
if err := rd.ReadFixedArrayLen(2); err != nil {
|
||||
// Read array length can be 2 or 4 (with CLAIM metadata)
|
||||
n, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
|
||||
if n != 2 && n != 4 {
|
||||
return XMessage{}, fmt.Errorf("redis: got %d elements in the XMessage array, expected 2 or 4", n)
|
||||
}
|
||||
|
||||
id, err := rd.ReadString()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
@@ -1657,10 +1732,24 @@ func readXMessage(rd *proto.Reader) (XMessage, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return XMessage{
|
||||
msg := XMessage{
|
||||
ID: id,
|
||||
Values: v,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if n == 4 {
|
||||
msg.MillisElapsedFromDelivery, err = rd.ReadInt()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
|
||||
msg.DeliveredCount, err = rd.ReadInt()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) {
|
||||
@@ -3768,6 +3857,83 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
|
||||
|
||||
//-----------------------------------------------------------------------
|
||||
|
||||
type Latency struct {
|
||||
Name string
|
||||
Time time.Time
|
||||
Latest time.Duration
|
||||
Max time.Duration
|
||||
}
|
||||
|
||||
type LatencyCmd struct {
|
||||
baseCmd
|
||||
val []Latency
|
||||
}
|
||||
|
||||
var _ Cmder = (*LatencyCmd)(nil)
|
||||
|
||||
func NewLatencyCmd(ctx context.Context, args ...interface{}) *LatencyCmd {
|
||||
return &LatencyCmd{
|
||||
baseCmd: baseCmd{
|
||||
ctx: ctx,
|
||||
args: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) SetVal(val []Latency) {
|
||||
cmd.val = val
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) Val() []Latency {
|
||||
return cmd.val
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) Result() ([]Latency, error) {
|
||||
return cmd.val, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) String() string {
|
||||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) readReply(rd *proto.Reader) error {
|
||||
n, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val = make([]Latency, n)
|
||||
for i := 0; i < len(cmd.val); i++ {
|
||||
nn, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nn < 3 {
|
||||
return fmt.Errorf("redis: got %d elements in latency get, expected at least 3", nn)
|
||||
}
|
||||
if cmd.val[i].Name, err = rd.ReadString(); err != nil {
|
||||
return err
|
||||
}
|
||||
createdAt, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Time = time.Unix(createdAt, 0)
|
||||
latest, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Latest = time.Duration(latest) * time.Millisecond
|
||||
maximum, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Max = time.Duration(maximum) * time.Millisecond
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------
|
||||
|
||||
type MapStringInterfaceCmd struct {
|
||||
baseCmd
|
||||
|
||||
|
||||
+53
-1
@@ -193,6 +193,7 @@ type Cmdable interface {
|
||||
ClientID(ctx context.Context) *IntCmd
|
||||
ClientUnblock(ctx context.Context, id int64) *IntCmd
|
||||
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
|
||||
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
|
||||
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
|
||||
ConfigResetStat(ctx context.Context) *StatusCmd
|
||||
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
|
||||
@@ -210,9 +211,13 @@ type Cmdable interface {
|
||||
ShutdownNoSave(ctx context.Context) *StatusCmd
|
||||
SlaveOf(ctx context.Context, host, port string) *StatusCmd
|
||||
SlowLogGet(ctx context.Context, num int64) *SlowLogCmd
|
||||
SlowLogLen(ctx context.Context) *IntCmd
|
||||
SlowLogReset(ctx context.Context) *StatusCmd
|
||||
Time(ctx context.Context) *TimeCmd
|
||||
DebugObject(ctx context.Context, key string) *StringCmd
|
||||
MemoryUsage(ctx context.Context, key string, samples ...int) *IntCmd
|
||||
Latency(ctx context.Context) *LatencyCmd
|
||||
LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd
|
||||
|
||||
ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd
|
||||
|
||||
@@ -519,6 +524,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
|
||||
// When enabled, the client will receive push notifications about Redis maintenance events.
|
||||
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
|
||||
args := []interface{}{"client", "maint_notifications"}
|
||||
if enabled {
|
||||
if endpointType == "" {
|
||||
endpointType = "none"
|
||||
}
|
||||
args = append(args, "on", "moving-endpoint-type", endpointType)
|
||||
} else {
|
||||
args = append(args, "off")
|
||||
}
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
|
||||
@@ -655,6 +677,34 @@ func (c cmdable) SlowLogGet(ctx context.Context, num int64) *SlowLogCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SlowLogLen(ctx context.Context) *IntCmd {
|
||||
cmd := NewIntCmd(ctx, "slowlog", "len")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SlowLogReset(ctx context.Context) *StatusCmd {
|
||||
cmd := NewStatusCmd(ctx, "slowlog", "reset")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) Latency(ctx context.Context) *LatencyCmd {
|
||||
cmd := NewLatencyCmd(ctx, "latency", "latest")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd {
|
||||
args := make([]interface{}, 2+len(events))
|
||||
args[0] = "latency"
|
||||
args[1] = "reset"
|
||||
copy(args[2:], events)
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Sync is not implemented: calling it always panics with "not implemented".
// The context argument is ignored.
func (c cmdable) Sync(_ context.Context) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -675,7 +725,9 @@ func (c cmdable) MemoryUsage(ctx context.Context, key string, samples ...int) *I
|
||||
args := []interface{}{"memory", "usage", key}
|
||||
if len(samples) > 0 {
|
||||
if len(samples) != 1 {
|
||||
panic("MemoryUsage expects single sample count")
|
||||
cmd := NewIntCmd(ctx)
|
||||
cmd.SetErr(errors.New("MemoryUsage expects single sample count"))
|
||||
return cmd
|
||||
}
|
||||
args = append(args, "SAMPLES", samples[0])
|
||||
}
|
||||
|
||||
+5
-5
@@ -2,7 +2,7 @@
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-standalone
|
||||
environment:
|
||||
@@ -23,7 +23,7 @@ services:
|
||||
- all
|
||||
|
||||
osscluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-osscluster
|
||||
environment:
|
||||
@@ -40,7 +40,7 @@ services:
|
||||
- all
|
||||
|
||||
sentinel-cluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-sentinel-cluster
|
||||
network_mode: "host"
|
||||
@@ -60,7 +60,7 @@ services:
|
||||
- all
|
||||
|
||||
sentinel:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-sentinel
|
||||
depends_on:
|
||||
@@ -84,7 +84,7 @@ services:
|
||||
- all
|
||||
|
||||
ring-cluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-ring-cluster
|
||||
environment:
|
||||
|
||||
+207
-45
@@ -52,34 +52,82 @@ type Error interface {
|
||||
var _ Error = proto.RedisError("")
|
||||
|
||||
func isContextError(err error) bool {
|
||||
switch err {
|
||||
case context.Canceled, context.DeadlineExceeded:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
// Check for wrapped context errors using errors.Is
|
||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// isTimeoutError checks if an error is a timeout error, even if wrapped.
|
||||
// Returns (isTimeout, shouldRetryOnTimeout) where:
|
||||
// - isTimeout: true if the error is any kind of timeout error
|
||||
// - shouldRetryOnTimeout: true if Timeout() method returns true
|
||||
func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) {
|
||||
// Check for timeoutError interface (works with wrapped errors)
|
||||
var te timeoutError
|
||||
if errors.As(err, &te) {
|
||||
return true, te.Timeout()
|
||||
}
|
||||
|
||||
// Check for net.Error specifically (common case for network timeouts)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return true, netErr.Timeout()
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func shouldRetry(err error, retryTimeout bool) bool {
|
||||
switch err {
|
||||
case io.EOF, io.ErrUnexpectedEOF:
|
||||
return true
|
||||
case nil, context.Canceled, context.DeadlineExceeded:
|
||||
if err == nil {
|
||||
return false
|
||||
case pool.ErrPoolTimeout:
|
||||
}
|
||||
|
||||
// Check for EOF errors (works with wrapped errors)
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for context errors (works with wrapped errors)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for pool timeout (works with wrapped errors)
|
||||
if errors.Is(err, pool.ErrPoolTimeout) {
|
||||
// connection pool timeout, increase retries. #3289
|
||||
return true
|
||||
}
|
||||
|
||||
if v, ok := err.(timeoutError); ok {
|
||||
if v.Timeout() {
|
||||
// Check for timeout errors (works with wrapped errors)
|
||||
if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout {
|
||||
if hasTimeoutFlag {
|
||||
return retryTimeout
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for typed Redis errors using errors.As (works with wrapped errors)
|
||||
if proto.IsMaxClientsError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsLoadingError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsReadOnlyError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsMasterDownError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsClusterDownError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsTryAgainError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback to string checking for backward compatibility with plain errors
|
||||
s := err.Error()
|
||||
if s == "ERR max number of clients reached" {
|
||||
if strings.HasPrefix(s, "ERR max number of clients reached") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "LOADING ") {
|
||||
@@ -88,29 +136,42 @@ func shouldRetry(err error, retryTimeout bool) bool {
|
||||
if strings.HasPrefix(s, "READONLY ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "CLUSTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "TRYAGAIN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isRedisError(err error) bool {
|
||||
_, ok := err.(proto.RedisError)
|
||||
return ok
|
||||
// Check if error implements the Error interface (works with wrapped errors)
|
||||
var redisErr Error
|
||||
if errors.As(err, &redisErr) {
|
||||
return true
|
||||
}
|
||||
// Also check for proto.RedisError specifically
|
||||
var protoRedisErr proto.RedisError
|
||||
return errors.As(err, &protoRedisErr)
|
||||
}
|
||||
|
||||
func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
switch err {
|
||||
case nil:
|
||||
if err == nil {
|
||||
return false
|
||||
case context.Canceled, context.DeadlineExceeded:
|
||||
}
|
||||
|
||||
// Check for context errors (works with wrapped errors)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for pool timeout errors (works with wrapped errors)
|
||||
if errors.Is(err, pool.ErrConnUnusableTimeout) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -131,7 +192,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
}
|
||||
|
||||
if allowTimeout {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Check for network timeout errors (works with wrapped errors)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -140,44 +203,143 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
}
|
||||
|
||||
func isMovedError(err error) (moved bool, ask bool, addr string) {
|
||||
if !isRedisError(err) {
|
||||
return
|
||||
// Check for typed MovedError
|
||||
if movedErr, ok := proto.IsMovedError(err); ok {
|
||||
addr = movedErr.Addr()
|
||||
addr = internal.GetAddr(addr)
|
||||
return true, false, addr
|
||||
}
|
||||
|
||||
// Check for typed AskError
|
||||
if askErr, ok := proto.IsAskError(err); ok {
|
||||
addr = askErr.Addr()
|
||||
addr = internal.GetAddr(addr)
|
||||
return false, true, addr
|
||||
}
|
||||
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
switch {
|
||||
case strings.HasPrefix(s, "MOVED "):
|
||||
moved = true
|
||||
case strings.HasPrefix(s, "ASK "):
|
||||
ask = true
|
||||
default:
|
||||
return
|
||||
if strings.HasPrefix(s, "MOVED ") {
|
||||
// Parse: MOVED 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
addr = internal.GetAddr(parts[2])
|
||||
return true, false, addr
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(s, "ASK ") {
|
||||
// Parse: ASK 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
addr = internal.GetAddr(parts[2])
|
||||
return false, true, addr
|
||||
}
|
||||
}
|
||||
|
||||
ind := strings.LastIndex(s, " ")
|
||||
if ind == -1 {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
addr = s[ind+1:]
|
||||
addr = internal.GetAddr(addr)
|
||||
return
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
func isLoadingError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "LOADING ")
|
||||
return proto.IsLoadingError(err)
|
||||
}
|
||||
|
||||
func isReadOnlyError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
return proto.IsReadOnlyError(err)
|
||||
}
|
||||
|
||||
func isMovedSameConnAddr(err error, addr string) bool {
|
||||
redisError := err.Error()
|
||||
if !strings.HasPrefix(redisError, "MOVED ") {
|
||||
return false
|
||||
if movedErr, ok := proto.IsMovedError(err); ok {
|
||||
return strings.HasSuffix(movedErr.Addr(), addr)
|
||||
}
|
||||
return strings.HasSuffix(redisError, " "+addr)
|
||||
return false
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Typed error checking functions for public use.
|
||||
// These functions work correctly even when errors are wrapped in hooks.
|
||||
|
||||
// IsLoadingError reports whether err is a Redis LOADING error, even if wrapped.
|
||||
// LOADING errors occur when Redis is loading the dataset in memory.
// The check is delegated to proto.IsLoadingError.
|
||||
func IsLoadingError(err error) bool {
|
||||
return proto.IsLoadingError(err)
|
||||
}
|
||||
|
||||
// IsReadOnlyError reports whether err is a Redis READONLY error, even if wrapped.
|
||||
// READONLY errors occur when trying to write to a read-only replica.
// The check is delegated to proto.IsReadOnlyError.
|
||||
func IsReadOnlyError(err error) bool {
|
||||
return proto.IsReadOnlyError(err)
|
||||
}
|
||||
|
||||
// IsClusterDownError reports whether err is a Redis CLUSTERDOWN error, even if wrapped.
|
||||
// CLUSTERDOWN errors occur when the cluster is down.
// The check is delegated to proto.IsClusterDownError.
|
||||
func IsClusterDownError(err error) bool {
|
||||
return proto.IsClusterDownError(err)
|
||||
}
|
||||
|
||||
// IsTryAgainError reports whether err is a Redis TRYAGAIN error, even if wrapped.
|
||||
// TRYAGAIN errors occur when a command cannot be processed and should be retried.
// The check is delegated to proto.IsTryAgainError.
|
||||
func IsTryAgainError(err error) bool {
|
||||
return proto.IsTryAgainError(err)
|
||||
}
|
||||
|
||||
// IsMasterDownError reports whether err is a Redis MASTERDOWN error, even if wrapped.
|
||||
// MASTERDOWN errors occur when the master is down.
// The check is delegated to proto.IsMasterDownError.
|
||||
func IsMasterDownError(err error) bool {
|
||||
return proto.IsMasterDownError(err)
|
||||
}
|
||||
|
||||
// IsMaxClientsError reports whether err is a Redis max clients error, even if wrapped.
|
||||
// This error occurs when the maximum number of clients has been reached.
// The check is delegated to proto.IsMaxClientsError.
|
||||
func IsMaxClientsError(err error) bool {
|
||||
return proto.IsMaxClientsError(err)
|
||||
}
|
||||
|
||||
// IsMovedError checks if an error is a Redis MOVED error, even if wrapped.
|
||||
// MOVED errors occur in cluster mode when a key has been moved to a different node.
|
||||
// Returns the address of the node where the key has been moved and a boolean indicating if it's a MOVED error.
|
||||
func IsMovedError(err error) (addr string, ok bool) {
|
||||
if movedErr, isMovedErr := proto.IsMovedError(err); isMovedErr {
|
||||
return movedErr.Addr(), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsAskError checks if an error is a Redis ASK error, even if wrapped.
|
||||
// ASK errors occur in cluster mode when a key is being migrated and the client should ask another node.
|
||||
// Returns the address of the node to ask and a boolean indicating if it's an ASK error.
|
||||
func IsAskError(err error) (addr string, ok bool) {
|
||||
if askErr, isAskErr := proto.IsAskError(err); isAskErr {
|
||||
return askErr.Addr(), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsAuthError reports whether err is a Redis authentication error, even if wrapped.
|
||||
// Authentication errors occur when:
|
||||
// - NOAUTH: Redis requires authentication but none was provided
|
||||
// - WRONGPASS: Redis authentication failed due to incorrect password
|
||||
// - unauthenticated: Error returned when password changed
//
// The check is delegated to proto.IsAuthError.
|
||||
func IsAuthError(err error) bool {
|
||||
return proto.IsAuthError(err)
|
||||
}
|
||||
|
||||
// IsPermissionError reports whether err is a Redis permission error, even if wrapped.
|
||||
// Permission errors (NOPERM) occur when a user does not have permission to execute a command.
// The check is delegated to proto.IsPermissionError.
|
||||
func IsPermissionError(err error) bool {
|
||||
return proto.IsPermissionError(err)
|
||||
}
|
||||
|
||||
// IsExecAbortError reports whether err is a Redis EXECABORT error, even if wrapped.
|
||||
// EXECABORT errors occur when a transaction is aborted.
// The check is delegated to proto.IsExecAbortError.
|
||||
func IsExecAbortError(err error) bool {
|
||||
return proto.IsExecAbortError(err)
|
||||
}
|
||||
|
||||
// IsOOMError reports whether err is a Redis OOM (Out Of Memory) error, even if wrapped.
|
||||
// OOM errors occur when Redis is out of memory.
// The check is delegated to proto.IsOOMError.
|
||||
func IsOOMError(err error) bool {
|
||||
return proto.IsOOMError(err)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
+4
-4
@@ -116,16 +116,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice
|
||||
|
||||
// HSet accepts values in following formats:
|
||||
//
|
||||
// - HSet("myhash", "key1", "value1", "key2", "value2")
|
||||
// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2")
|
||||
//
|
||||
// - HSet("myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
//
|
||||
// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
//
|
||||
// Playing struct With "redis" tag.
|
||||
// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` }
|
||||
//
|
||||
// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
//
|
||||
// For struct, can be a structure pointer type, we only parse the field whose tag is redis.
|
||||
// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it,
|
||||
|
||||
Generated
Vendored
+100
@@ -0,0 +1,100 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// ConnReAuthCredentialsListener is a credentials listener for a specific connection
|
||||
// that triggers re-authentication when credentials change.
|
||||
//
|
||||
// This listener implements the auth.CredentialsListener interface and is subscribed
|
||||
// to a StreamingCredentialsProvider. When new credentials are received via OnNext,
|
||||
// it marks the connection for re-authentication through the manager.
|
||||
//
|
||||
// The re-authentication is always performed asynchronously to avoid blocking the
|
||||
// credentials provider and to prevent potential deadlocks with the pool semaphore.
|
||||
// The actual re-auth happens when the connection is returned to the pool in an idle state.
|
||||
//
|
||||
// Lifecycle:
|
||||
// - Created during connection initialization via Manager.Listener()
|
||||
// - Subscribed to the StreamingCredentialsProvider
|
||||
// - Receives credential updates via OnNext()
|
||||
// - Cleaned up when connection is removed from pool via Manager.RemoveListener()
|
||||
type ConnReAuthCredentialsListener struct {
|
||||
// reAuth is the function to re-authenticate the connection with new credentials
|
||||
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
|
||||
|
||||
// onErr is the function to call when re-authentication or acquisition fails
|
||||
onErr func(conn *pool.Conn, err error)
|
||||
|
||||
// conn is the connection this listener is associated with
|
||||
conn *pool.Conn
|
||||
|
||||
// manager is the streaming credentials manager for coordinating re-auth
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// OnNext is called when new credentials are received from the StreamingCredentialsProvider.
|
||||
//
|
||||
// This method marks the connection for asynchronous re-authentication. The actual
|
||||
// re-authentication happens in the background when the connection is returned to the
|
||||
// pool and is in an idle state.
|
||||
//
|
||||
// Asynchronous re-auth is used to:
|
||||
// - Avoid blocking the credentials provider's notification goroutine
|
||||
// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes)
|
||||
// - Ensure re-auth happens when the connection is safe to use (not processing commands)
|
||||
//
|
||||
// The reAuthFn callback receives:
|
||||
// - nil if the connection was successfully acquired for re-auth
|
||||
// - error if acquisition timed out or failed
|
||||
//
|
||||
// Thread-safe: Called by the credentials provider's notification goroutine.
|
||||
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
|
||||
if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Always use async reauth to avoid complex pool semaphore issues
|
||||
// The synchronous path can cause deadlocks in the pool's semaphore mechanism
|
||||
// when called from the Subscribe goroutine, especially with small pool sizes.
|
||||
// The connection pool hook will re-authenticate the connection when it is
|
||||
// returned to the pool in a clean, idle state.
|
||||
c.manager.MarkForReAuth(c.conn, func(err error) {
|
||||
// err is from connection acquisition (timeout, etc.)
|
||||
if err != nil {
|
||||
// Log the error
|
||||
c.OnError(err)
|
||||
return
|
||||
}
|
||||
// err is from reauth command execution
|
||||
err = c.reAuth(c.conn, credentials)
|
||||
if err != nil {
|
||||
// Log the error
|
||||
c.OnError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnError is called when an error occurs during credential streaming or re-authentication.
|
||||
//
|
||||
// This method can be called from:
|
||||
// - The StreamingCredentialsProvider when there's an error in the credentials stream
|
||||
// - The re-auth process when connection acquisition times out
|
||||
// - The re-auth process when the AUTH command fails
|
||||
//
|
||||
// The error is delegated to the onErr callback provided during listener creation.
|
||||
//
|
||||
// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker).
|
||||
func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
||||
if c.onErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.onErr(c.conn, err)
|
||||
}
|
||||
|
||||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
)
|
||||
|
||||
// CredentialsListeners is a thread-safe collection of credentials listeners
|
||||
// indexed by connection ID.
|
||||
//
|
||||
// This collection is used by the Manager to maintain a registry of listeners
|
||||
// for each connection in the pool. Listeners are reused when connections are
|
||||
// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions
|
||||
// to the StreamingCredentialsProvider.
|
||||
//
|
||||
// The collection supports concurrent access from multiple goroutines during
|
||||
// connection initialization, credential updates, and connection removal.
|
||||
type CredentialsListeners struct {
|
||||
// listeners maps connection ID to credentials listener
|
||||
listeners map[uint64]auth.CredentialsListener
|
||||
|
||||
// lock protects concurrent access to the listeners map
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCredentialsListeners creates a new thread-safe credentials listeners collection.
|
||||
func NewCredentialsListeners() *CredentialsListeners {
|
||||
return &CredentialsListeners{
|
||||
listeners: make(map[uint64]auth.CredentialsListener),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds or updates a credentials listener for a connection.
|
||||
//
|
||||
// If a listener already exists for the connection ID, it is replaced.
|
||||
// This is safe because the old listener should have been unsubscribed
|
||||
// before the connection was reinitialized.
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.listeners == nil {
|
||||
c.listeners = make(map[uint64]auth.CredentialsListener)
|
||||
}
|
||||
c.listeners[connID] = listener
|
||||
}
|
||||
|
||||
// Get retrieves the credentials listener for a connection.
|
||||
//
|
||||
// Returns:
|
||||
// - listener: The credentials listener for the connection, or nil if not found
|
||||
// - ok: true if a listener exists for the connection ID, false otherwise
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
if len(c.listeners) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
listener, ok := c.listeners[connID]
|
||||
return listener, ok
|
||||
}
|
||||
|
||||
// Remove deletes the credentials listener registered for a connection.
|
||||
//
|
||||
// This is called when a connection is removed from the pool to prevent
|
||||
// memory leaks. If no listener exists for the connection ID, this is a no-op.
|
||||
//
|
||||
// Thread-safe: the write lock is held while the entry is deleted.
|
||||
func (c *CredentialsListeners) Remove(connID uint64) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
delete(c.listeners, connID)
|
||||
}
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// Manager coordinates streaming credentials and re-authentication for a connection pool.
|
||||
//
|
||||
// The manager is responsible for:
|
||||
// - Creating and managing per-connection credentials listeners
|
||||
// - Providing the pool hook for re-authentication
|
||||
// - Coordinating between credentials updates and pool operations
|
||||
//
|
||||
// When credentials change via a StreamingCredentialsProvider:
|
||||
// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update
|
||||
// 2. It calls MarkForReAuth on the manager
|
||||
// 3. The manager delegates to the pool hook
|
||||
// 4. The pool hook schedules background re-authentication
|
||||
//
|
||||
// The manager maintains a registry of credentials listeners indexed by connection ID,
|
||||
// allowing listener reuse when connections are reinitialized (e.g., after handoff).
|
||||
type Manager struct {
|
||||
// credentialsListeners maps connection ID to credentials listener
|
||||
credentialsListeners *CredentialsListeners
|
||||
|
||||
// pool is the connection pool being managed
|
||||
pool pool.Pooler
|
||||
|
||||
// poolHookRef is the re-authentication pool hook
|
||||
poolHookRef *ReAuthPoolHook
|
||||
}
|
||||
|
||||
// NewManager creates a new streaming credentials manager.
|
||||
//
|
||||
// Parameters:
|
||||
// - pl: The connection pool to manage
|
||||
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
|
||||
//
|
||||
// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that
|
||||
// re-auth operations don't exhaust the connection pool.
|
||||
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
|
||||
m := &Manager{
|
||||
pool: pl,
|
||||
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
|
||||
credentialsListeners: NewCredentialsListeners(),
|
||||
}
|
||||
m.poolHookRef.manager = m
|
||||
return m
|
||||
}
|
||||
|
||||
// PoolHook returns the manager's ReAuthPoolHook as a pool.PoolHook.
|
||||
//
|
||||
// This hook should be registered with the connection pool to enable
|
||||
// automatic re-authentication when credentials change.
|
||||
func (m *Manager) PoolHook() pool.PoolHook {
|
||||
return m.poolHookRef
|
||||
}
|
||||
|
||||
// Listener returns or creates a credentials listener for a connection.
|
||||
//
|
||||
// This method is called during connection initialization to set up the
|
||||
// credentials listener. If a listener already exists for the connection ID
|
||||
// (e.g., after a handoff), it is reused.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolCn: The connection to create/get a listener for
|
||||
// - reAuth: Function to re-authenticate the connection with new credentials
|
||||
// - onErr: Function to call when re-authentication fails
|
||||
//
|
||||
// Returns:
|
||||
// - auth.CredentialsListener: The listener to subscribe to the credentials provider
|
||||
// - error: Non-nil if poolCn is nil
|
||||
//
|
||||
// Note: The reAuth and onErr callbacks are captured once when the listener is
|
||||
// created and reused for the connection's lifetime. They should not change.
|
||||
//
|
||||
// Thread-safe: Can be called concurrently during connection initialization.
|
||||
func (m *Manager) Listener(
|
||||
poolCn *pool.Conn,
|
||||
reAuth func(*pool.Conn, auth.Credentials) error,
|
||||
onErr func(*pool.Conn, error),
|
||||
) (auth.CredentialsListener, error) {
|
||||
if poolCn == nil {
|
||||
return nil, errors.New("poolCn cannot be nil")
|
||||
}
|
||||
connID := poolCn.GetID()
|
||||
// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
|
||||
// so we can get the old listener from the cache and use it.
|
||||
// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
|
||||
listener, ok := m.credentialsListeners.Get(connID)
|
||||
if !ok || listener == nil {
|
||||
// Create new listener for this connection
|
||||
// Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime
|
||||
newCredListener := &ConnReAuthCredentialsListener{
|
||||
conn: poolCn,
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
manager: m,
|
||||
}
|
||||
|
||||
m.credentialsListeners.Add(connID, newCredListener)
|
||||
listener = newCredListener
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// MarkForReAuth marks a connection for re-authentication.
|
||||
//
|
||||
// This method is called by the credentials listener when new credentials are
|
||||
// received. It forwards the connection's ID and reAuthFn to the pool hook's MarkForReAuth.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolCn: The connection to re-authenticate; it is not checked for nil before GetID is called
|
||||
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
|
||||
//
|
||||
// Thread-safe: Called by credentials listeners when credentials change.
|
||||
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
|
||||
connID := poolCn.GetID()
|
||||
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
|
||||
}
|
||||
|
||||
// RemoveListener removes the credentials listener for a connection.
|
||||
//
|
||||
// This method is called by the pool hook's OnRemove to clean up listeners
|
||||
// when connections are removed from the pool. It delegates to CredentialsListeners.Remove.
|
||||
//
|
||||
// Parameters:
|
||||
// - connID: The connection ID whose listener should be removed
|
||||
//
|
||||
// Thread-safe: Called during connection removal.
|
||||
func (m *Manager) RemoveListener(connID uint64) {
|
||||
m.credentialsListeners.Remove(connID)
|
||||
}
|
||||
+241
@@ -0,0 +1,241 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// ReAuthPoolHook is a pool hook that manages background re-authentication of connections
|
||||
// when credentials change via a streaming credentials provider.
|
||||
//
|
||||
// The hook uses a semaphore-based worker pool to limit concurrent re-authentication
|
||||
// operations and prevent pool exhaustion. When credentials change, connections are
|
||||
// marked for re-authentication and processed asynchronously in the background.
|
||||
//
|
||||
// The re-authentication process:
|
||||
// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth
|
||||
// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth)
|
||||
// 3. A worker goroutine acquires the connection (waits until it's not in use)
|
||||
// 4. Executes the re-auth function while holding the connection
|
||||
// 5. Releases the connection back to the pool
|
||||
//
|
||||
// The hook ensures that:
|
||||
// - Only one re-auth operation runs per connection at a time
|
||||
// - Connections are not used for commands during re-authentication
|
||||
// - Re-auth operations timeout if they can't acquire the connection
|
||||
// - Resources are properly cleaned up on connection removal
|
||||
type ReAuthPoolHook struct {
|
||||
// shouldReAuth maps connection ID to re-auth function
|
||||
// Connections in this map need re-authentication but haven't been scheduled yet
|
||||
shouldReAuth map[uint64]func(error)
|
||||
shouldReAuthLock sync.RWMutex
|
||||
|
||||
// workers is a semaphore limiting concurrent re-auth operations
|
||||
// Initialized with poolSize tokens to prevent pool exhaustion
|
||||
// Uses FastSemaphore for better performance with eventual fairness
|
||||
workers *internal.FastSemaphore
|
||||
|
||||
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
|
||||
reAuthTimeout time.Duration
|
||||
|
||||
// scheduledReAuth maps connection ID to scheduled status
|
||||
// Connections in this map have a background worker attempting re-authentication
|
||||
scheduledReAuth map[uint64]bool
|
||||
scheduledLock sync.RWMutex
|
||||
|
||||
// manager is a back-reference for cleanup operations
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// NewReAuthPoolHook creates a new re-authentication pool hook.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size)
|
||||
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
|
||||
//
|
||||
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
|
||||
// re-auth operations don't exhaust the connection pool.
|
||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||
return &ReAuthPoolHook{
|
||||
shouldReAuth: make(map[uint64]func(error)),
|
||||
scheduledReAuth: make(map[uint64]bool),
|
||||
workers: internal.NewFastSemaphore(int32(poolSize)),
|
||||
reAuthTimeout: reAuthTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkForReAuth marks a connection for re-authentication.
|
||||
//
|
||||
// This method is called when credentials change and a connection needs to be
|
||||
// re-authenticated. It only records reAuthFn under connID in shouldReAuth,
|
||||
// replacing any function already stored for that ID; the re-authentication itself happens later (in OnPut).
|
||||
//
|
||||
// Parameters:
|
||||
// - connID: The connection ID to mark for re-authentication
|
||||
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
|
||||
//
|
||||
// Thread-safe: the shouldReAuth write lock is held while the entry is stored.
|
||||
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
||||
r.shouldReAuthLock.Lock()
|
||||
defer r.shouldReAuthLock.Unlock()
|
||||
r.shouldReAuth[connID] = reAuthFn
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
//
|
||||
// This hook checks if the connection needs re-authentication or has a scheduled
|
||||
// re-auth operation. If so, it rejects the connection (returns accept=false),
|
||||
// causing the pool to try another connection.
|
||||
//
|
||||
// Returns:
|
||||
// - accept: false if connection needs re-auth, true otherwise
|
||||
// - err: always nil (errors are not used in this hook)
|
||||
//
|
||||
// Thread-safe: Called concurrently by multiple goroutines getting connections.
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
connID := conn.GetID()
|
||||
r.shouldReAuthLock.RLock()
|
||||
_, shouldReAuth := r.shouldReAuth[connID]
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
// This connection was marked for reauth while in the pool,
|
||||
// reject the connection
|
||||
if shouldReAuth {
|
||||
// simply reject the connection, it will be re-authenticated in OnPut
|
||||
return false, nil
|
||||
}
|
||||
r.scheduledLock.RLock()
|
||||
_, hasScheduled := r.scheduledReAuth[connID]
|
||||
r.scheduledLock.RUnlock()
|
||||
// has scheduled reauth, reject the connection
|
||||
if hasScheduled {
|
||||
// simply reject the connection, it currently has a reauth scheduled
|
||||
// and the worker is waiting for slot to execute the reauth
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
//
|
||||
// This hook checks if the connection needs re-authentication. If so, it schedules
|
||||
// a background goroutine to perform the re-auth asynchronously. The goroutine:
|
||||
// 1. Waits for a worker slot (semaphore)
|
||||
// 2. Acquires the connection (waits until not in use)
|
||||
// 3. Executes the re-auth function
|
||||
// 4. Releases the connection and worker slot
|
||||
//
|
||||
// The connection is always pooled (not removed) since re-auth happens in background.
|
||||
//
|
||||
// Returns:
|
||||
// - shouldPool: always true (connection stays in pool during background re-auth)
|
||||
// - shouldRemove: always false
|
||||
// - err: always nil
|
||||
//
|
||||
// Thread-safe: Called concurrently by multiple goroutines returning connections.
|
||||
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
||||
if conn == nil {
|
||||
// noop
|
||||
return true, false, nil
|
||||
}
|
||||
connID := conn.GetID()
|
||||
// Check if reauth is needed and get the function with proper locking
|
||||
r.shouldReAuthLock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[connID]
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
|
||||
if ok {
|
||||
// Acquire both locks to atomically move from shouldReAuth to scheduledReAuth
|
||||
// This prevents race conditions where OnGet might miss the transition
|
||||
r.shouldReAuthLock.Lock()
|
||||
r.scheduledLock.Lock()
|
||||
r.scheduledReAuth[connID] = true
|
||||
delete(r.shouldReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Unlock()
|
||||
go func() {
|
||||
r.workers.AcquireBlocking()
|
||||
// safety first
|
||||
if conn == nil || (conn != nil && conn.IsClosed()) {
|
||||
r.workers.Release()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
// once again - safety first
|
||||
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
|
||||
}
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.workers.Release()
|
||||
}()
|
||||
|
||||
// Create timeout context for connection acquisition
|
||||
// This prevents indefinite waiting if the connection is stuck
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to acquire the connection for re-authentication
|
||||
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
|
||||
// This prevents re-authentication from interfering with active commands
|
||||
// Use AwaitAndTransition to wait for the connection to become IDLE
|
||||
stateMachine := conn.GetStateMachine()
|
||||
if stateMachine == nil {
|
||||
// No state machine - should not happen, but handle gracefully
|
||||
reAuthFn(pool.ErrConnUnusableTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := stateMachine.AwaitAndTransition(ctx, pool.ValidFromIdle(), pool.StateUnusable)
|
||||
if err != nil {
|
||||
// Timeout or other error occurred, cannot acquire connection
|
||||
reAuthFn(err)
|
||||
return
|
||||
}
|
||||
|
||||
// safety first
|
||||
if !conn.IsClosed() {
|
||||
// Successfully acquired the connection, perform reauth
|
||||
reAuthFn(nil)
|
||||
}
|
||||
|
||||
// Release the connection: transition from UNUSABLE back to IDLE
|
||||
stateMachine.Transition(pool.StateIdle)
|
||||
}()
|
||||
}
|
||||
|
||||
// the reauth will happen in background, as far as the pool is concerned:
|
||||
// pool the connection, don't remove it, no error
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// OnRemove is called when a connection is removed from the pool.
|
||||
//
|
||||
// This hook cleans up all state associated with the connection:
|
||||
// - Removes from shouldReAuth map (pending re-auth)
|
||||
// - Removes from scheduledReAuth map (active re-auth)
|
||||
// - Removes credentials listener from manager
|
||||
//
|
||||
// This prevents memory leaks and ensures that removed connections don't have
|
||||
// lingering re-auth operations or listeners.
|
||||
//
|
||||
// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure.
|
||||
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
||||
connID := conn.GetID()
|
||||
r.shouldReAuthLock.Lock()
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, connID)
|
||||
delete(r.shouldReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Unlock()
|
||||
if r.manager != nil {
|
||||
r.manager.RemoveListener(connID)
|
||||
}
|
||||
}
|
||||
|
||||
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
// Package interfaces provides shared interfaces used by both the main redis package
|
||||
// and the maintnotifications upgrade package to avoid circular dependencies.
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NotificationProcessor is a forward declaration of the push notification
// processor (most probably push.NotificationProcessor), repeated here so
// that this package does not have to import it and create an import cycle.
type NotificationProcessor interface {
	// RegisterHandler registers handler under pushNotificationName.
	RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error

	// UnregisterHandler removes the handler registered under pushNotificationName.
	UnregisterHandler(pushNotificationName string) error

	// GetHandler returns the handler registered under pushNotificationName.
	GetHandler(pushNotificationName string) interface{}
}
|
||||
|
||||
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
|
||||
type ClientInterface interface {
|
||||
// GetOptions returns the client options.
|
||||
GetOptions() OptionsInterface
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
GetPushProcessor() NotificationProcessor
|
||||
}
|
||||
|
||||
// OptionsInterface gives read access to client options.
// It is an adapter-style interface so that this package does not need to
// depend on the concrete options type (avoiding circular dependencies).
type OptionsInterface interface {
	// GetReadTimeout reports the read timeout.
	GetReadTimeout() time.Duration

	// GetWriteTimeout reports the write timeout.
	GetWriteTimeout() time.Duration

	// GetNetwork reports the network type.
	GetNetwork() string

	// GetAddr reports the connection address.
	GetAddr() string

	// IsTLSEnabled reports whether TLS is enabled.
	IsTLSEnabled() bool

	// GetProtocol reports the protocol version.
	GetProtocol() int

	// GetPoolSize reports the connection pool size.
	GetPoolSize() int

	// NewDialer builds a new dialer function for the connection.
	NewDialer() func(context.Context) (net.Conn, error)
}
|
||||
+57
-4
@@ -7,20 +7,73 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// TODO (ned): Revisit logging
|
||||
// Add more standardized approach with log levels and configurability
|
||||
|
||||
// Logging is the minimal logging interface used by the library.
type Logging interface {
	// Printf logs a message built from format and v in the manner of
	// fmt.Printf.
	Printf(ctx context.Context, format string, v ...interface{})
}

// DefaultLogger is the Logging implementation backed by a standard library
// *log.Logger.
type DefaultLogger struct {
	log *log.Logger
}

// Printf formats the message with fmt.Sprintf and writes it through the
// wrapped logger. The context is not used.
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
	// Call depth 2 so that Lshortfile reports the caller of Printf rather
	// than this method. The write error is deliberately ignored: there is
	// nowhere better to report a failure of the logger itself.
	_ = l.log.Output(2, fmt.Sprintf(format, v...))
}

// NewDefaultLogger returns a DefaultLogger that writes to stderr with the
// "redis: " prefix, standard date/time flags and the short file name.
func NewDefaultLogger() Logging {
	return &DefaultLogger{
		log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
	}
}

// Logger is the package-wide logger. By default it prints to stderr.
// Arguments are handled in the manner of fmt.Printf.
var Logger Logging = NewDefaultLogger()
|
||||
|
||||
// LogLevelT represents the logging level.
type LogLevelT int

// Log level constants for the entire go-redis library. Each level includes
// every level declared before it.
const (
	LogLevelError LogLevelT = iota // 0 - errors only
	LogLevelWarn                   // 1 - warnings and errors
	LogLevelInfo                   // 2 - info, warnings, and errors
	LogLevelDebug                  // 3 - debug, info, warnings, and errors
)

// LogLevel is the global log level; it starts at LogLevelError.
var LogLevel LogLevelT = LogLevelError

// String returns the upper-case name of the level, or "UNKNOWN" for a value
// outside the declared range.
func (l LogLevelT) String() string {
	if !l.IsValid() {
		return "UNKNOWN"
	}
	// Indexed by level value; IsValid guarantees l is within bounds.
	names := [...]string{"ERROR", "WARN", "INFO", "DEBUG"}
	return names[l]
}

// IsValid reports whether l is one of the declared log levels.
func (l LogLevelT) IsValid() bool {
	return LogLevelError <= l && l <= LogLevelDebug
}

// WarnOrAbove reports whether warnings are enabled at level l.
func (l LogLevelT) WarnOrAbove() bool {
	return LogLevelWarn <= l
}

// InfoOrAbove reports whether info messages are enabled at level l.
func (l LogLevelT) InfoOrAbove() bool {
	return LogLevelInfo <= l
}

// DebugOrAbove reports whether debug messages are enabled at level l.
func (l LogLevelT) DebugOrAbove() bool {
	return LogLevelDebug <= l
}
|
||||
|
||||
Generated
Vendored
+625
@@ -0,0 +1,625 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
|
||||
func appendJSONIfDebug(message string, data map[string]interface{}) string {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
jsonData, _ := json.Marshal(data)
|
||||
return fmt.Sprintf("%s %s", message, string(jsonData))
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
// Message fragments used by the log helper functions of this package,
// grouped by the source file that emits them.
const (
	// ========================================
	// CIRCUIT_BREAKER.GO - Circuit breaker management
	// ========================================
	CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
	CircuitBreakerOpenedMessage = "circuit breaker opened"
	CircuitBreakerReopenedMessage = "circuit breaker reopened"
	CircuitBreakerClosedMessage = "circuit breaker closed"
	CircuitBreakerCleanupMessage = "circuit breaker cleanup"
	CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"

	// ========================================
	// CONFIG.GO - Configuration and debug
	// ========================================
	DebugLoggingEnabledMessage = "debug logging enabled"
	ConfigDebugMessage = "config debug"

	// ========================================
	// ERRORS.GO - Error message constants
	// ========================================
	InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
	InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
	InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
	InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
	InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
	InvalidEndpointTypeErrorMessage = "invalid endpoint type"
	InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
	InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
	InvalidClientErrorMessage = "invalid client type"
	InvalidNotificationErrorMessage = "invalid notification format"
	MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
	HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
	InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
	InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
	InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
	ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
	ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
	ShutdownErrorMessage = "shutdown"
	CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"

	// ========================================
	// EXAMPLE_HOOKS.GO - Example metrics hooks
	// ========================================
	MetricsHookProcessingNotificationMessage = "metrics hook processing"
	MetricsHookRecordedErrorMessage = "metrics hook recorded error"

	// ========================================
	// HANDOFF_WORKER.GO - Connection handoff processing
	// ========================================
	HandoffStartedMessage = "handoff started"
	HandoffFailedMessage = "handoff failed"
	ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
	ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
	HandoffRetryAttemptMessage = "Performing handoff"
	CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
	HandoffQueueFullMessage = "handoff queue is full"
	FailedToDialNewEndpointMessage = "failed to dial new endpoint"
	ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
	HandoffSuccessMessage = "handoff succeeded"
	RemovingConnectionFromPoolMessage = "removing connection from pool"
	NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
	WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
	WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
	WorkerPanicRecoveredMessage = "worker panic recovered"
	WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
	ReachedMaxHandoffRetriesMessage = "reached max handoff retries"

	// ========================================
	// MANAGER.GO - Moving operation tracking and handler registration
	// ========================================
	DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
	TrackingMovingOperationMessage = "tracking MOVING operation"
	UntrackingMovingOperationMessage = "untracking MOVING operation"
	OperationNotTrackedMessage = "operation not tracked"
	FailedToRegisterHandlerMessage = "failed to register handler"

	// ========================================
	// HOOKS.GO - Notification processing hooks
	// ========================================
	ProcessingNotificationMessage = "processing notification started"
	// Spelling corrected from "proccessing".
	ProcessingNotificationFailedMessage = "processing notification failed"
	ProcessingNotificationSucceededMessage = "processing notification succeeded"

	// ========================================
	// POOL_HOOK.GO - Pool connection management
	// ========================================
	FailedToQueueHandoffMessage = "failed to queue handoff"
	MarkedForHandoffMessage = "connection marked for handoff"

	// ========================================
	// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
	// ========================================
	InvalidNotificationFormatMessage = "invalid notification format"
	InvalidNotificationTypeFormatMessage = "invalid notification type format"
	InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
	InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
	InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
	NoConnectionInHandlerContextMessage = "no connection in handler context"
	InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
	SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
	RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
	UnrelaxedTimeoutMessage = "clearing relaxed timeout"
	ManagerNotInitializedMessage = "manager not initialized"
	FailedToMarkForHandoffMessage = "failed to mark connection for handoff"

	// ========================================
	// used in pool/conn
	// ========================================
	UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
)
|
||||
|
||||
func HandoffStarted(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
"attempt": attempt,
|
||||
"maxAttempts": maxAttempts,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffSucceeded(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Timeout-related log functions
|
||||
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeout(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff queue and marking functions
|
||||
func HandoffQueueFull(queueLen, queueCap int) string {
|
||||
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"queueLen": queueLen,
|
||||
"queueCap": queueCap,
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToQueueHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToMarkForHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"maxRetries": maxRetries,
|
||||
})
|
||||
}
|
||||
|
||||
// Notification processing functions
|
||||
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
// Moving operation tracking functions
|
||||
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func UntrackingMovingOperation(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func OperationNotTracked(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
// Connection pool functions
|
||||
func RemovingConnectionFromPool(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker functions
|
||||
func CircuitBreakerOpen(connID uint64, endpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Additional handoff functions for specific cases
|
||||
func ConnectionNotMarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func ConnectionNotMarkedForHandoffError(connID uint64) string {
|
||||
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
|
||||
}
|
||||
|
||||
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"retries": retries,
|
||||
"newEndpoint": newEndpoint,
|
||||
"oldEndpoint": oldEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CannotQueueHandoffForRetry(err error) string {
|
||||
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validation and error functions
|
||||
func InvalidNotificationFormat(notification interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNotificationTypeFormat(notificationType interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": fmt.Sprintf("%v", notificationType),
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidNotification creates a log message for invalid notifications of any type
|
||||
func InvalidNotification(notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"seqID": fmt.Sprintf("%v", seqID),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidTimeSInMovingNotification(timeS interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeS": fmt.Sprintf("%v", timeS),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
|
||||
})
|
||||
}
|
||||
|
||||
func NoConnectionInHandlerContext(notificationType string) string {
|
||||
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
|
||||
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connType": fmt.Sprintf("%T", conn),
|
||||
})
|
||||
}
|
||||
|
||||
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seconds": seconds,
|
||||
})
|
||||
}
|
||||
|
||||
func ManagerNotInitialized() string {
|
||||
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func FailedToRegisterHandler(notificationType string, err error) string {
|
||||
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ShutdownError() string {
|
||||
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration validation error functions
|
||||
func InvalidRelaxedTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffWorkersError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffQueueSizeError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidPostHandoffRelaxedDurationError() string {
|
||||
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidEndpointTypeError() string {
|
||||
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidMaintNotificationsError() string {
|
||||
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffRetriesError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidClientError() string {
|
||||
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidNotificationError() string {
|
||||
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func MaxHandoffRetriesReachedError() string {
|
||||
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func HandoffQueueFullError() string {
|
||||
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerFailureThresholdError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerResetTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerMaxRequestsError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration and debug functions
|
||||
func DebugLoggingEnabled() string {
|
||||
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func ConfigDebug(config interface{}) string {
|
||||
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"config": fmt.Sprintf("%+v", config),
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff worker functions
|
||||
func WorkerExitingDueToShutdown() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToShutdownWhileProcessing() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerPanicRecovered(panicValue interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"panic": fmt.Sprintf("%v", panicValue),
|
||||
})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
|
||||
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
"until": until,
|
||||
})
|
||||
}
|
||||
|
||||
// Example hooks functions
|
||||
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
|
||||
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
|
||||
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Pool hook functions
|
||||
func MarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker additional functions
|
||||
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerOpened(endpoint string, failures int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"failures": failures,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerReopened(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerClosed(endpoint string, successes int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"successes": successes,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerCleanup(removed int, total int) string {
|
||||
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"removed": removed,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
// trailingJSONObjectRe matches a brace-delimited section that runs to the very
// end of a log message. It is compiled once at package initialisation instead
// of on every ExtractDataFromLogMessage call.
var trailingJSONObjectRe = regexp.MustCompile(`(\{.*\})$`)

// ExtractDataFromLogMessage extracts the structured data that a
// maintnotifications log message carries as a trailing JSON object.
//
// The candidate section starts at the first '{' on the line and must end the
// message with '}'. It is decoded with encoding/json, so numbers come back as
// float64.
//
// Example: `conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}`
// Returns: map[string]interface{}{"connID": float64(123), "endpoint": "localhost:6379"}
//
// When the message has no trailing brace section, or that section is not valid
// JSON, an empty non-nil map is returned.
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
	matches := trailingJSONObjectRe.FindStringSubmatch(logMessage)
	if len(matches) < 2 {
		return map[string]interface{}{}
	}

	// The capture group always contains at least "{}", so no emptiness check
	// is needed before decoding.
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(matches[1]), &parsed); err != nil {
		return map[string]interface{}{}
	}
	return parsed
}
|
||||
+789
-21
@@ -1,28 +1,124 @@
|
||||
// Package pool implements the pool management
|
||||
package pool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
// noDeadline is the zero time.Time value.
var noDeadline time.Time
|
||||
|
||||
// Sentinel errors created once at package initialisation so that hot paths can
// return them without allocating a new error value each time.
var (
	// Handoff marking errors.
	errAlreadyMarkedForHandoff = errors.New("connection is already marked for handoff")
	errNotMarkedForHandoff     = errors.New("connection was not marked for handoff")
	errHandoffStateChanged     = errors.New("handoff state changed during marking")

	// Connection availability errors.
	errConnectionNotAvailable   = errors.New("redis: connection not available")
	errConnNotAvailableForWrite = errors.New("redis: connection not available for write operation")
)
|
||||
|
||||
// getCachedTimeNs returns the current wall-clock time as Unix nanoseconds.
//
// Despite the name, nothing is cached any more. An earlier version read a
// global value refreshed by a background goroutine, but its 50ms ticker kept
// waking the scheduler and burned CPU while the client was idle. Calling
// time.Now() directly is fast enough on modern systems (vDSO on Linux): it
// costs only ~1-2% in extreme high-concurrency benchmarks and removes the
// idle CPU usage.
func getCachedTimeNs() int64 {
	now := time.Now()
	return now.UnixNano()
}
|
||||
|
||||
// GetCachedTimeNs returns the current time in nanoseconds.
|
||||
// Exported for use by other packages that need fast time access.
|
||||
func GetCachedTimeNs() int64 {
|
||||
return getCachedTimeNs()
|
||||
}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
// HandoffState bundles everything that describes a pending connection handoff.
// A pointer to it is stored atomically as a single unit so that a reader never
// sees the handoff flag from one update combined with the parameters of
// another.
type HandoffState struct {
	// ShouldHandoff reports whether the connection should be handed off.
	ShouldHandoff bool
	// Endpoint is the new endpoint for the handoff.
	Endpoint string
	// SeqID is the sequence ID from the MOVING notification.
	SeqID int64
}
|
||||
|
||||
// atomicNetConn boxes a net.Conn so that every value placed in an atomic.Value
// has the same concrete type (*atomicNetConn), whatever net.Conn
// implementation it carries.
type atomicNetConn struct {
	conn net.Conn
}
|
||||
|
||||
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||
func generateConnID() uint64 {
|
||||
return atomic.AddUint64(&connIDCounter, 1)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
usedAt int64 // atomic
|
||||
netConn net.Conn
|
||||
// Connection identifier for unique tracking
|
||||
id uint64
|
||||
|
||||
usedAt atomic.Int64
|
||||
lastPutAt atomic.Int64
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
Inited bool
|
||||
// Lightweight mutex to protect reader operations during handoff
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
// State machine for connection state management
|
||||
// Replaces: usable, Inited, used
|
||||
// Provides thread-safe state transitions with FIFO waiting queue
|
||||
// States: CREATED → INITIALIZING → IDLE ⇄ IN_USE
|
||||
// ↓
|
||||
// UNUSABLE (handoff/reauth)
|
||||
// ↓
|
||||
// IDLE/CLOSED
|
||||
stateMachine *ConnStateMachine
|
||||
|
||||
// Handoff metadata - managed separately from state machine
|
||||
// These are atomic for lock-free access during handoff operations
|
||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||
handoffRetriesAtomic atomic.Uint32 // retry counter
|
||||
|
||||
pooled bool
|
||||
pubsub bool
|
||||
closed atomic.Bool
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
|
||||
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
|
||||
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||
|
||||
// Counter to track multiple relaxed timeout setters if we have nested calls
|
||||
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
|
||||
// if counter reaches 0, we clear the relaxed timeouts
|
||||
relaxedCounter atomic.Int32
|
||||
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
@@ -32,9 +128,11 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
}
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
now := time.Now()
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
createdAt: now,
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
stateMachine: NewConnStateMachine(),
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -50,37 +148,656 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||
}
|
||||
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
cn.SetUsedAt(now)
|
||||
// Initialize handoff state atomically
|
||||
initialHandoffState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
cn.handoffStateAtomic.Store(initialHandoffState)
|
||||
return cn
|
||||
}
|
||||
|
||||
func (cn *Conn) UsedAt() time.Time {
|
||||
unix := atomic.LoadInt64(&cn.usedAt)
|
||||
return time.Unix(unix, 0)
|
||||
return time.Unix(0, cn.usedAt.Load())
|
||||
}
|
||||
func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
cn.usedAt.Store(tm.UnixNano())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
func (cn *Conn) UsedAtNs() int64 {
|
||||
return cn.usedAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetUsedAtNs(ns int64) {
|
||||
cn.usedAt.Store(ns)
|
||||
}
|
||||
|
||||
func (cn *Conn) LastPutAtNs() int64 {
|
||||
return cn.lastPutAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetLastPutAtNs(ns int64) {
|
||||
cn.lastPutAt.Store(ns)
|
||||
}
|
||||
|
||||
// Backward-compatible wrapper methods for state machine
|
||||
// These maintain the existing API while using the new state machine internally
|
||||
|
||||
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
|
||||
//
|
||||
// This is used by background operations (handoff, re-auth) to acquire exclusive
|
||||
// access to a connection. The operation sets usable to false, preventing the pool
|
||||
// from returning the connection to clients.
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
//
|
||||
// Implementation note: This is a compatibility wrapper around the state machine.
|
||||
// It checks if the current state is "usable" (IDLE or IN_USE) and transitions accordingly.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
||||
currentState := cn.stateMachine.GetState()
|
||||
|
||||
// Check if current state matches the "old" usable value
|
||||
currentUsable := (currentState == StateIdle || currentState == StateInUse)
|
||||
if currentUsable != old {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're trying to set to the same value, succeed immediately
|
||||
if old == new {
|
||||
return true
|
||||
}
|
||||
|
||||
// Transition based on new value
|
||||
if new {
|
||||
// Trying to make usable - transition from UNUSABLE to IDLE
|
||||
// This should only work from UNUSABLE or INITIALIZING states
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromInitializingOrUnusable,
|
||||
StateIdle,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
// Trying to make unusable - transition from IDLE to UNUSABLE
|
||||
// This is typically for acquiring the connection for background operations
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromIdle,
|
||||
StateUnusable,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
//
|
||||
// A connection is "usable" when it's in a stable state and can be returned to clients.
|
||||
// It becomes unusable during:
|
||||
// - Handoff operations (network connection replacement)
|
||||
// - Re-authentication (credential updates)
|
||||
// - Other background operations that need exclusive access
|
||||
//
|
||||
// Note: CREATED state is considered usable because new connections need to pass OnGet() hook
|
||||
// before initialization. The initialization happens after OnGet() in the client code.
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
state := cn.stateMachine.GetState()
|
||||
// CREATED, IDLE, and IN_USE states are considered usable
|
||||
// CREATED: new connection, not yet initialized (will be initialized by client)
|
||||
// IDLE: initialized and ready to be acquired
|
||||
// IN_USE: usable but currently acquired by someone
|
||||
return state == StateCreated || state == StateIdle || state == StateInUse
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This should be called to mark a connection as usable after initialization or
|
||||
// to release it after a background operation completes.
|
||||
//
|
||||
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
if usable {
|
||||
// Transition to IDLE state (ready to be acquired)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
} else {
|
||||
// Transition to UNUSABLE state (for background operations)
|
||||
cn.stateMachine.Transition(StateUnusable)
|
||||
}
|
||||
}
|
||||
|
||||
// IsInited returns true if the connection has been initialized.
|
||||
// This is a backward-compatible wrapper around the state machine.
|
||||
func (cn *Conn) IsInited() bool {
|
||||
state := cn.stateMachine.GetState()
|
||||
// Connection is initialized if it's in IDLE or any post-initialization state
|
||||
return state != StateCreated && state != StateInitializing && state != StateClosed
|
||||
}
|
||||
|
||||
// Used - State machine based implementation
|
||||
|
||||
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This is the preferred method for acquiring a connection from the pool, as it
|
||||
// ensures that only one goroutine marks the connection as used.
|
||||
//
|
||||
// Implementation: Uses state machine transitions IDLE ⇄ IN_USE
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
||||
if old == new {
|
||||
// No change needed
|
||||
currentState := cn.stateMachine.GetState()
|
||||
currentUsed := (currentState == StateInUse)
|
||||
return currentUsed == old
|
||||
}
|
||||
|
||||
if !old && new {
|
||||
// Acquiring: IDLE → IN_USE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
|
||||
return err == nil
|
||||
} else {
|
||||
// Releasing: IN_USE → IDLE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsUsed returns true if the connection is currently in use (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// A connection is "used" when it has been retrieved from the pool and is
|
||||
// actively processing a command. Background operations (like re-auth) should
|
||||
// wait until the connection is not used before executing commands.
|
||||
func (cn *Conn) IsUsed() bool {
|
||||
return cn.stateMachine.GetState() == StateInUse
|
||||
}
|
||||
|
||||
// SetUsed sets the used flag for the connection (lock-free).
|
||||
//
|
||||
// This should be called when returning a connection to the pool (set to false)
|
||||
// or when a single-connection pool retrieves its connection (set to true).
|
||||
//
|
||||
// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
|
||||
// avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsed(val bool) {
|
||||
if val {
|
||||
cn.stateMachine.Transition(StateInUse)
|
||||
} else {
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
// This is the fast path for accessing netConn without mutex overhead.
|
||||
func (cn *Conn) getNetConn() net.Conn {
|
||||
if v := cn.netConnAtomic.Load(); v != nil {
|
||||
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||
return wrapper.conn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setNetConn stores the network connection atomically (lock-free).
|
||||
// This is used for the fast path of connection replacement.
|
||||
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// Handoff state management - atomic access to handoff metadata
|
||||
|
||||
// ShouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).ShouldHandoff
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).Endpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).SeqID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetHandoffInfo returns all handoff information atomically (lock-free).
|
||||
// This method prevents race conditions by returning all handoff state in a single atomic operation.
|
||||
// Returns (shouldHandoff, endpoint, seqID).
|
||||
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
state := v.(*HandoffState)
|
||||
return state.ShouldHandoff, state.Endpoint, state.SeqID
|
||||
}
|
||||
return false, "", 0
|
||||
}
|
||||
|
||||
// HandoffRetries returns the current handoff retry count (lock-free).
|
||||
func (cn *Conn) HandoffRetries() int {
|
||||
return int(cn.handoffRetriesAtomic.Load())
|
||||
}
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(n)))
|
||||
}
|
||||
|
||||
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
|
||||
func (cn *Conn) IsPooled() bool {
|
||||
return cn.pooled
|
||||
}
|
||||
|
||||
// IsPubSub returns true if the connection is used for PubSub.
|
||||
func (cn *Conn) IsPubSub() bool {
|
||||
return cn.pubsub
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
}
|
||||
|
||||
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||
// After the deadline, timeouts automatically revert to normal values.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
|
||||
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||
}
|
||||
|
||||
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) ClearRelaxedTimeout() {
|
||||
// Atomically decrement counter and check if we should clear
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
|
||||
// Use atomic load to get current value for CAS to avoid stale value race
|
||||
current := cn.relaxedCounter.Load()
|
||||
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) clearRelaxedTimeout() {
|
||||
cn.relaxedReadTimeoutNs.Store(0)
|
||||
cn.relaxedWriteTimeoutNs.Store(0)
|
||||
cn.relaxedDeadlineNs.Store(0)
|
||||
cn.relaxedCounter.Store(0)
|
||||
}
|
||||
|
||||
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||
// This checks both the timeout values and the deadline (if set).
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||
// Fast path: no relaxed timeouts are set
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// If no relaxed timeouts are set, return false
|
||||
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, relaxed timeouts are active
|
||||
if deadlineNs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If deadline is set, check if it's still in the future
|
||||
return time.Now().UnixNano() < deadlineNs
|
||||
}
|
||||
|
||||
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if readTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(readTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if writeTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(writeTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||
cn.initConnFunc = fn
|
||||
}
|
||||
|
||||
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||
if cn.initConnFunc != nil {
|
||||
return cn.initConnFunc(ctx, cn)
|
||||
}
|
||||
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
// Store the new connection atomically first (lock-free)
|
||||
cn.setNetConn(netConn)
|
||||
// Protect reader reset operations to avoid data races
|
||||
// Use write lock since we're modifying the reader state
|
||||
cn.readerMu.Lock()
|
||||
cn.rd.Reset(netConn)
|
||||
cn.readerMu.Unlock()
|
||||
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
|
||||
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||
// This method is used by the pool for health checks and provides better performance.
|
||||
func (cn *Conn) GetNetConn() net.Conn {
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
// This method ensures only one initialization can happen at a time by using atomic state transitions.
|
||||
// If another goroutine is currently initializing, this will wait for it to complete.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
||||
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
||||
// If another goroutine is initializing, we'll wait for it to finish
|
||||
// if the context has a deadline, use that, otherwise use the connection read (relaxed) timeout
|
||||
// which should be set during handoff. If it is not set, use a 5 second default
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
deadline = time.Now().Add(cn.getEffectiveReadTimeout(5 * time.Second))
|
||||
}
|
||||
waitCtx, cancel := context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.AwaitAndTransition(
|
||||
waitCtx,
|
||||
validFromCreatedIdleOrUnusable,
|
||||
StateInitializing,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
|
||||
}
|
||||
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
|
||||
// Execute initialization
|
||||
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
|
||||
// or CLOSED on failure. We don't need to do it here.
|
||||
// NOTE: Initconn returns conn in IDLE state
|
||||
initErr := cn.ExecuteInitConn(ctx)
|
||||
if initErr != nil {
|
||||
// ExecuteInitConn already transitioned to CLOSED, just return the error
|
||||
return initErr
|
||||
}
|
||||
|
||||
// ExecuteInitConn already transitioned to IDLE
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification.
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
// Note: This only sets metadata - the connection state is not changed until OnPut.
|
||||
// This allows the current user to finish using the connection before handoff.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
// Check if already marked for handoff
|
||||
if cn.ShouldHandoff() {
|
||||
return errAlreadyMarkedForHandoff
|
||||
}
|
||||
|
||||
// Set handoff metadata atomically
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: true,
|
||||
Endpoint: newEndpoint,
|
||||
SeqID: seqID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkQueuedForHandoff marks the connection as queued for handoff processing.
|
||||
// This makes the connection unusable until handoff completes.
|
||||
// This is called from OnPut hook, where the connection is typically in IN_USE state.
|
||||
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
|
||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
// Get current handoff state
|
||||
currentState := cn.handoffStateAtomic.Load()
|
||||
if currentState == nil {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
state := currentState.(*HandoffState)
|
||||
if !state.ShouldHandoff {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
// Create new state with ShouldHandoff=false but preserve endpoint and seqID
|
||||
// This prevents the connection from being queued multiple times while still
|
||||
// allowing the worker to access the handoff metadata
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: state.Endpoint, // Preserve endpoint for handoff processing
|
||||
SeqID: state.SeqID, // Preserve seqID for handoff processing
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update state
|
||||
if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
// State changed between load and CAS - retry or return error
|
||||
return errHandoffStateChanged
|
||||
}
|
||||
|
||||
// Transition to UNUSABLE from IN_USE (normal flow), IDLE (edge cases), or CREATED (tests/uninitialized)
|
||||
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
||||
// But in some edge cases or tests, it might be in IDLE or CREATED state
|
||||
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
|
||||
if err != nil {
|
||||
// Check if already in UNUSABLE state (race condition or retry)
|
||||
// ShouldHandoff should be false now, but check just in case
|
||||
if finalState == StateUnusable && !cn.ShouldHandoff() {
|
||||
// Already unusable - this is fine, keep the new handoff state
|
||||
return nil
|
||||
}
|
||||
// Restore the original state if transition fails for other reasons
|
||||
cn.handoffStateAtomic.Store(currentState)
|
||||
return fmt.Errorf("failed to mark connection as unusable: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetID returns the unique identifier for this connection.
|
||||
func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// GetStateMachine returns the connection's state machine for advanced state management.
|
||||
// This is primarily used by internal packages like maintnotifications for handoff processing.
|
||||
func (cn *Conn) GetStateMachine() *ConnStateMachine {
|
||||
return cn.stateMachine
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire the connection for use.
|
||||
// This is an optimized inline method for the hot path (Get operation).
|
||||
//
|
||||
// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED.
|
||||
// Returns true if the connection was successfully acquired, false otherwise.
|
||||
// The CREATED->CREATED is done so we can keep the state correct for later
|
||||
// initialization of the connection in initConn.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast()
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// The IDLE->IN_USE and CREATED->CREATED transitions don't need
|
||||
// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever
|
||||
// needs to notify waiters on these transitions, update this to use TryTransitionFast().
|
||||
func (cn *Conn) TryAcquire() bool {
|
||||
// The || operator short-circuits, so only 1 CAS in the common case
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) ||
|
||||
cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated))
|
||||
}
|
||||
|
||||
// Release releases the connection back to the pool.
|
||||
// This is an optimized inline method for the hot path (Put operation).
|
||||
//
|
||||
// It tries to transition from IN_USE -> IDLE.
|
||||
// Returns true if the connection was successfully released, false otherwise.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast().
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// If the state machine ever needs to notify waiters
|
||||
// on this transition, update this to use TryTransitionFast().
|
||||
func (cn *Conn) Release() bool {
|
||||
// Inline the hot path - single CAS operation
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle))
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff.
|
||||
// Makes the connection usable again.
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// Clear handoff metadata
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
})
|
||||
|
||||
// Reset retry counter
|
||||
cn.handoffRetriesAtomic.Store(0)
|
||||
|
||||
// Mark connection as usable again
|
||||
// Use state machine directly instead of deprecated SetUsable
|
||||
// probably done by initConn
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) HasBufferedData() bool {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
return cn.rd.Buffered() > 0
|
||||
}
|
||||
|
||||
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
|
||||
if cn.rd.Buffered() <= 0 {
|
||||
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||
}
|
||||
return cn.rd.PeekReplyType()
|
||||
}
|
||||
|
||||
func (cn *Conn) Write(b []byte) (int, error) {
|
||||
return cn.netConn.Write(b)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Write(b)
|
||||
}
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) RemoteAddr() net.Addr {
|
||||
if cn.netConn != nil {
|
||||
return cn.netConn.RemoteAddr()
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -89,7 +806,16 @@ func (cn *Conn) WithReader(
|
||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return errConnectionNotAvailable
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -100,13 +826,25 @@ func (cn *Conn) WithWriter(
|
||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
return err
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Set write deadline on the connection
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Connection is not available - return preallocated error
|
||||
return errConnNotAvailableForWrite
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the buffered writer if needed, should not happen
|
||||
if cn.bw.Buffered() > 0 {
|
||||
cn.bw.Reset(cn.netConn)
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(cn.wr); err != nil {
|
||||
@@ -116,17 +854,47 @@ func (cn *Conn) WithWriter(
|
||||
return cn.bw.Flush()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load() || cn.stateMachine.GetState() == StateClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
|
||||
// Transition to CLOSED state
|
||||
cn.stateMachine.Transition(StateClosed)
|
||||
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
||||
// This is used to check if there are push notifications available
|
||||
// Important: This will work on Linux, but not on Windows
|
||||
func (cn *Conn) MaybeHasData() bool {
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return maybeHasData(netConn)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deadline computes the effective deadline time based on context and timeout.
|
||||
// It updates the usedAt timestamp to now.
|
||||
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
tm := time.Now()
|
||||
cn.SetUsedAt(tm)
|
||||
// Use cached time for deadline calculation (called 2x per command: read + write)
|
||||
nowNs := getCachedTimeNs()
|
||||
cn.SetUsedAtNs(nowNs)
|
||||
tm := time.Unix(0, nowNs)
|
||||
|
||||
if timeout > 0 {
|
||||
tm = tm.Add(timeout)
|
||||
|
||||
+11
-1
@@ -12,6 +12,9 @@ import (
|
||||
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
// connCheck checks if the connection is still alive and if there is data in the socket
|
||||
// it will try to peek at the next byte without consuming it since we may want to work with it
|
||||
// later on (e.g. push notifications)
|
||||
func connCheck(conn net.Conn) error {
|
||||
// Reset previous timeout.
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
if err := rawConn.Read(func(fd uintptr) bool {
|
||||
var buf [1]byte
|
||||
n, err := syscall.Read(int(fd), buf[:])
|
||||
// Use MSG_PEEK to peek at data without consuming it
|
||||
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
|
||||
|
||||
switch {
|
||||
case n == 0 && err == nil:
|
||||
sysErr = io.EOF
|
||||
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
return sysErr
|
||||
}
|
||||
|
||||
// maybeHasData checks if there is data in the socket without consuming it
|
||||
func maybeHasData(conn net.Conn) bool {
|
||||
return connCheck(conn) == errUnexpectedRead
|
||||
}
|
||||
|
||||
+13
-2
@@ -2,8 +2,19 @@
|
||||
|
||||
package pool
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func connCheck(conn net.Conn) error {
|
||||
// errUnexpectedRead is placeholder error variable for non-unix build constraints
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
// connCheck is the stub used where the socket-level check is unavailable:
// it performs no inspection and always reports the connection as healthy.
func connCheck(_ net.Conn) error {
	return nil
}
|
||||
|
||||
// maybeHasData cannot inspect the socket in this build, so it conservatively
// assumes that data may be present and always returns true.
func maybeHasData(_ net.Conn) bool {
	return true
}
|
||||
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnState is the state of a connection inside the connection state
// machine. Values are small integers so they are cheap to store and compare.
//
// State transitions:
//
//	CREATED → INITIALIZING → IDLE ⇄ IN_USE
//
// A connection can also become UNUSABLE (handoff/reauth), from which it
// moves on to IDLE or CLOSED.
type ConnState uint32

const (
	// StateCreated: the connection exists but has not been initialized yet.
	StateCreated ConnState = iota

	// StateInitializing: initialization is currently running.
	StateInitializing

	// StateIdle: initialized and sitting in the pool, ready to be acquired.
	StateIdle

	// StateInUse: taken from the pool and actively processing a command.
	StateInUse

	// StateUnusable: temporarily blocked by a background operation
	// (handoff, reauth, etc.); it cannot be acquired from the pool.
	StateUnusable

	// StateClosed: the connection has been closed.
	StateClosed
)
|
||||
|
||||
// Predefined state slices to avoid allocations in hot paths
|
||||
var (
|
||||
validFromInUse = []ConnState{StateInUse}
|
||||
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
|
||||
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
|
||||
// For AwaitAndTransition calls
|
||||
validFromCreatedIdleOrUnusable = []ConnState{StateCreated, StateIdle, StateUnusable}
|
||||
validFromIdle = []ConnState{StateIdle}
|
||||
// For CompareAndSwapUsable
|
||||
validFromInitializingOrUnusable = []ConnState{StateInitializing, StateUnusable}
|
||||
)
|
||||
|
||||
// Accessor functions for predefined slices to avoid allocations in external packages
|
||||
// These return the same slice instance, so they're zero-allocation
|
||||
|
||||
// ValidFromIdle returns a predefined slice containing only StateIdle.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromIdle() []ConnState {
|
||||
return validFromIdle
|
||||
}
|
||||
|
||||
// ValidFromCreatedIdleOrUnusable returns a predefined slice for initialization transitions.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromCreatedIdleOrUnusable() []ConnState {
|
||||
return validFromCreatedIdleOrUnusable
|
||||
}
|
||||
|
||||
// String returns a human-readable string representation of the state.
|
||||
func (s ConnState) String() string {
|
||||
switch s {
|
||||
case StateCreated:
|
||||
return "CREATED"
|
||||
case StateInitializing:
|
||||
return "INITIALIZING"
|
||||
case StateIdle:
|
||||
return "IDLE"
|
||||
case StateInUse:
|
||||
return "IN_USE"
|
||||
case StateUnusable:
|
||||
return "UNUSABLE"
|
||||
case StateClosed:
|
||||
return "CLOSED"
|
||||
default:
|
||||
return fmt.Sprintf("UNKNOWN(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel errors reported by the connection state machine.
var (
	// ErrInvalidStateTransition reports a transition that is not allowed from
	// the current state.
	ErrInvalidStateTransition = errors.New("invalid state transition")

	// ErrStateMachineClosed reports an operation on a closed state machine.
	ErrStateMachineClosed = errors.New("state machine is closed")

	// ErrTimeout reports that a state transition timed out.
	ErrTimeout = errors.New("state transition timeout")
)
|
||||
|
||||
// waiter represents a goroutine waiting for a state transition.
|
||||
// Designed for minimal allocations and fast processing.
|
||||
type waiter struct {
|
||||
validStates map[ConnState]struct{} // States we're waiting for
|
||||
targetState ConnState // State to transition to
|
||||
done chan error // Signaled when transition completes or times out
|
||||
}
|
||||
|
||||
// ConnStateMachine tracks a connection's state and lets callers wait, in
// FIFO order, for a state they can transition from.
//
// Design goals:
//   - lock-free state reads on the hot path
//   - minimal allocations
//   - fast state transitions
//   - FIFO fairness among waiters
//
// Handoff metadata (endpoint, seqID, retries) is not kept here; it lives in
// the Conn struct.
type ConnStateMachine struct {
	// state holds the current ConnState; atomic so reads need no lock.
	state atomic.Uint32

	// mu guards waiters; it is taken only to add, remove or notify waiters.
	mu sync.Mutex
	// waiters is the FIFO queue; each element's Value is a *waiter.
	waiters *list.List
	// waiterCount mirrors the queue length so hot paths can skip mu entirely
	// when nobody is waiting.
	waiterCount atomic.Int32
}
|
||||
|
||||
// NewConnStateMachine creates a new connection state machine.
|
||||
// Initial state is StateCreated.
|
||||
func NewConnStateMachine() *ConnStateMachine {
|
||||
sm := &ConnStateMachine{
|
||||
waiters: list.New(),
|
||||
}
|
||||
sm.state.Store(uint32(StateCreated))
|
||||
return sm
|
||||
}
|
||||
|
||||
// GetState returns the current state (lock-free read).
|
||||
// This is the hot path - optimized for zero allocations and minimal overhead.
|
||||
// Note: Zero allocations applies to state reads; converting the returned state to a string
|
||||
// (via String()) may allocate if the state is unknown.
|
||||
func (sm *ConnStateMachine) GetState() ConnState {
|
||||
return ConnState(sm.state.Load())
|
||||
}
|
||||
|
||||
// TryTransitionFast is an optimized version for the hot path (Get/Put operations).
|
||||
// It only handles simple state transitions without waiter notification.
|
||||
// This is safe because:
|
||||
// 1. Get/Put don't need to wait for state changes
|
||||
// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match
|
||||
// 3. If a background operation is in progress (state is UNUSABLE), this fails fast
|
||||
//
|
||||
// Returns true if transition succeeded, false otherwise.
|
||||
// Use this for performance-critical paths where you don't need error details.
|
||||
//
|
||||
// Performance: Single CAS operation - as fast as the old atomic bool!
|
||||
// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target)
|
||||
// The || operator short-circuits, so only 1 CAS is executed in the common case.
|
||||
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
|
||||
return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState))
|
||||
}
|
||||
|
||||
// TryTransition attempts an immediate state transition without waiting.
|
||||
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// This is faster than AwaitAndTransition when you don't need to wait.
|
||||
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
||||
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
||||
//
|
||||
// Performance: Zero allocations on success path (hot path).
|
||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
|
||||
// Try each valid from state with CAS
|
||||
// This ensures only ONE goroutine can successfully transition at a time
|
||||
for _, fromState := range validFromStates {
|
||||
// Try to atomically swap from fromState to targetState
|
||||
// If successful, we won the race and can proceed
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
// Hot path optimization: only check for waiters if transition succeeded
|
||||
// This avoids atomic load on every Get/Put when no waiters exist
|
||||
if sm.waiterCount.Load() > 0 {
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All CAS attempts failed - state is not valid for this transition
|
||||
// Return the current state so caller can decide what to do
|
||||
// Note: This error path allocates, but it's the exceptional case
|
||||
currentState := sm.GetState()
|
||||
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
||||
}
|
||||
|
||||
// Transition unconditionally transitions to the target state.
|
||||
// Use with caution - prefer AwaitAndTransition or TryTransition for safety.
|
||||
// This is useful for error paths or when you know the transition is valid.
|
||||
func (sm *ConnStateMachine) Transition(targetState ConnState) {
|
||||
sm.state.Store(uint32(targetState))
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
|
||||
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
||||
// then atomically transitions to the target state.
|
||||
// Returns the current state after the transition attempt and an error if the operation failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// Returns error if timeout expires or context is cancelled.
|
||||
//
|
||||
// This method implements FIFO fairness - the first caller to wait gets priority
|
||||
// when the state becomes available.
|
||||
//
|
||||
// Performance notes:
|
||||
// - If already in a valid state, this is very fast (no allocation, no waiting)
|
||||
// - If waiting is required, allocates one waiter struct and one channel
|
||||
func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
ctx context.Context,
|
||||
validFromStates []ConnState,
|
||||
targetState ConnState,
|
||||
) (ConnState, error) {
|
||||
// Fast path: try immediate transition with CAS to prevent race conditions
|
||||
// BUT: only if there are no waiters in the queue (to maintain FIFO ordering)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
for _, fromState := range validFromStates {
|
||||
// Check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
sm.notifyWaiters()
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path failed - check if we should wait or fail
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if closed
|
||||
if currentState == StateClosed {
|
||||
return currentState, ErrStateMachineClosed
|
||||
}
|
||||
|
||||
// Slow path: need to wait for state change
|
||||
// Create waiter with valid states map for fast lookup
|
||||
validStatesMap := make(map[ConnState]struct{}, len(validFromStates))
|
||||
for _, s := range validFromStates {
|
||||
validStatesMap[s] = struct{}{}
|
||||
}
|
||||
|
||||
w := &waiter{
|
||||
validStates: validStatesMap,
|
||||
targetState: targetState,
|
||||
done: make(chan error, 1), // Buffered to avoid goroutine leak
|
||||
}
|
||||
|
||||
// Add to FIFO queue
|
||||
sm.mu.Lock()
|
||||
elem := sm.waiters.PushBack(w)
|
||||
sm.waiterCount.Add(1)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Wait for state change or timeout
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Timeout or cancellation - remove from queue
|
||||
sm.mu.Lock()
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
sm.mu.Unlock()
|
||||
return sm.GetState(), ctx.Err()
|
||||
case err := <-w.done:
|
||||
// Transition completed (or failed)
|
||||
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
|
||||
// or here (on timeout/cancellation).
|
||||
return sm.GetState(), err
|
||||
}
|
||||
}
|
||||
|
||||
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
|
||||
// This is called after every state transition.
|
||||
func (sm *ConnStateMachine) notifyWaiters() {
|
||||
// Fast path: check atomic counter without acquiring lock
|
||||
// This eliminates mutex overhead in the common case (no waiters)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock (waiters might have been processed)
|
||||
if sm.waiters.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Process waiters in FIFO order until no more can be processed
|
||||
// We loop instead of recursing to avoid stack overflow and mutex issues
|
||||
for {
|
||||
processed := false
|
||||
|
||||
// Find the first waiter that can proceed
|
||||
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
||||
w := elem.Value.(*waiter)
|
||||
|
||||
// Read current state inside the loop to get the latest value
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if current state is valid for this waiter
|
||||
if _, valid := w.validStates[currentState]; valid {
|
||||
// Remove from queue first
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
|
||||
// Use CAS to ensure state hasn't changed since we checked
|
||||
// This prevents race condition where another thread changes state
|
||||
// between our check and our transition
|
||||
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
|
||||
// Successfully transitioned - notify waiter
|
||||
w.done <- nil
|
||||
processed = true
|
||||
break
|
||||
} else {
|
||||
// State changed - re-add waiter to front of queue to maintain FIFO ordering
|
||||
// This waiter was first in line and should retain priority
|
||||
sm.waiters.PushFront(w)
|
||||
sm.waiterCount.Add(1)
|
||||
// Continue to next iteration to re-read state
|
||||
processed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't process any waiter, we're done
|
||||
if !processed {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolHook defines the interface for connection lifecycle hooks.
|
||||
type PoolHook interface {
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
// It can modify the connection or return an error to prevent its use.
|
||||
// The accept flag can be used to prevent the connection from being used.
|
||||
// On Accept = false the connection is rejected and returned to the pool.
|
||||
// The error can be used to prevent the connection from being used and returned to the pool.
|
||||
// On Errors, the connection is removed from the pool.
|
||||
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
|
||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
// It returns whether the connection should be pooled and whether it should be removed.
|
||||
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
|
||||
|
||||
// OnRemove is called when a connection is removed from the pool.
|
||||
// This happens when:
|
||||
// - Connection fails health check
|
||||
// - Connection exceeds max lifetime
|
||||
// - Pool is being closed
|
||||
// - Connection encounters an error
|
||||
// Implementations should clean up any per-connection state.
|
||||
// The reason parameter indicates why the connection was removed.
|
||||
OnRemove(ctx context.Context, conn *Conn, reason error)
|
||||
}
|
||||
|
||||
// PoolHookManager manages multiple pool hooks.
|
||||
type PoolHookManager struct {
|
||||
hooks []PoolHook
|
||||
hooksMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolHookManager creates a new pool hook manager.
|
||||
func NewPoolHookManager() *PoolHookManager {
|
||||
return &PoolHookManager{
|
||||
hooks: make([]PoolHook, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddHook adds a pool hook to the manager.
|
||||
// Hooks are called in the order they were added.
|
||||
func (phm *PoolHookManager) AddHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
phm.hooks = append(phm.hooks, hook)
|
||||
}
|
||||
|
||||
// RemoveHook removes a pool hook from the manager.
|
||||
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
|
||||
for i, h := range phm.hooks {
|
||||
if h == hook {
|
||||
// Remove hook by swapping with last element and truncating
|
||||
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
|
||||
phm.hooks = phm.hooks[:len(phm.hooks)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !acceptConn {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
|
||||
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
shouldPool = true // Default to pooling the connection
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
|
||||
|
||||
if hookErr != nil {
|
||||
return false, true, hookErr
|
||||
}
|
||||
|
||||
// If any hook says to remove or not pool, respect that decision
|
||||
if hookShouldRemove {
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
if !hookShouldPool {
|
||||
shouldPool = false
|
||||
}
|
||||
}
|
||||
|
||||
return shouldPool, false, nil
|
||||
}
|
||||
|
||||
// ProcessOnRemove calls all OnRemove hooks in order.
|
||||
func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hook.OnRemove(ctx, conn, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHookCount returns the number of registered hooks (for testing).
|
||||
func (phm *PoolHookManager) GetHookCount() int {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
return len(phm.hooks)
|
||||
}
|
||||
|
||||
// GetHooks returns a copy of all registered hooks.
|
||||
func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
hooks := make([]PoolHook, len(phm.hooks))
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
|
||||
|
||||
// Clone creates a copy of the hook manager with the same hooks.
|
||||
// This is used for lock-free atomic updates of the hook manager.
|
||||
func (phm *PoolHookManager) Clone() *PoolHookManager {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
newManager := &PoolHookManager{
|
||||
hooks: make([]PoolHook, len(phm.hooks)),
|
||||
}
|
||||
copy(newManager.hooks, phm.hooks)
|
||||
return newManager
|
||||
}
|
||||
+656
-148
File diff suppressed because it is too large
Load Diff
+50
-4
@@ -1,7 +1,13 @@
|
||||
package pool
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleConnPool is a pool that always returns the same connection.
|
||||
// Note: This pool is not thread-safe.
|
||||
// It is intended to be used by clients that need a single connection.
|
||||
type SingleConnPool struct {
|
||||
pool Pooler
|
||||
cn *Conn
|
||||
@@ -10,6 +16,12 @@ type SingleConnPool struct {
|
||||
|
||||
var _ Pooler = (*SingleConnPool)(nil)
|
||||
|
||||
// NewSingleConnPool creates a new single connection pool.
|
||||
// The pool will always return the same connection.
|
||||
// The pool will not:
|
||||
// - Close the connection
|
||||
// - Reconnect the connection
|
||||
// - Track the connection in any way
|
||||
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
|
||||
return &SingleConnPool{
|
||||
pool: pool,
|
||||
@@ -25,20 +37,47 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
|
||||
return p.pool.CloseConn(cn)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
|
||||
if p.stickyErr != nil {
|
||||
return nil, p.stickyErr
|
||||
}
|
||||
if p.cn == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
|
||||
// - During initialization (connection is in INITIALIZING state)
|
||||
// - During re-authentication (connection is in UNUSABLE state)
|
||||
// - For transactions (connection might be in various states)
|
||||
// We use SetUsed() which forces the transition, rather than TryTransition() which
|
||||
// would fail if the connection is not in IDLE/CREATED state.
|
||||
p.cn.SetUsed(true)
|
||||
p.cn.SetUsedAt(time.Now())
|
||||
return p.cn, nil
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
|
||||
func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
|
||||
if p.cn == nil {
|
||||
return
|
||||
}
|
||||
if p.cn != cn {
|
||||
return
|
||||
}
|
||||
p.cn.SetUsed(false)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
cn.SetUsed(false)
|
||||
p.cn = nil
|
||||
p.stickyErr = reason
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for SingleConnPool
|
||||
// since SingleConnPool doesn't use a turn-based queue system.
|
||||
func (p *SingleConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Close() error {
|
||||
p.cn = nil
|
||||
p.stickyErr = ErrClosed
|
||||
@@ -53,6 +92,13 @@ func (p *SingleConnPool) IdleLen() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Size returns the maximum pool size, which is always 1 for SingleConnPool
// regardless of whether the pool currently holds a connection.
func (p *SingleConnPool) Size() int { return 1 }
|
||||
|
||||
func (p *SingleConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
// AddPoolHook is a no-op: the body is empty, so the supplied hook is discarded.
func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}
|
||||
|
||||
// RemovePoolHook is a no-op: the body is empty, so the argument is ignored.
func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}
|
||||
|
||||
+13
@@ -123,6 +123,12 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
p.ch <- cn
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for StickyConnPool
|
||||
// since StickyConnPool doesn't use a turn-based queue system.
|
||||
func (p *StickyConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Close() error {
|
||||
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
|
||||
return nil
|
||||
@@ -196,6 +202,13 @@ func (p *StickyConnPool) IdleLen() int {
|
||||
return len(p.ch)
|
||||
}
|
||||
|
||||
// Size returns the maximum pool size, which is always 1 for StickyConnPool
// regardless of the pool's current state.
func (p *StickyConnPool) Size() int { return 1 }
|
||||
|
||||
func (p *StickyConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
// AddPoolHook is a no-op: the body is empty, so the supplied hook is discarded.
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
// RemovePoolHook is a no-op: the body is empty, so the argument is ignored.
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
||||
+80
@@ -0,0 +1,80 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// PubSubStats holds the counters maintained by PubSubPool.
type PubSubStats struct {
	// Created is incremented by PubSubPool.NewConn for each connection it creates.
	Created uint32
	// Untracked is incremented by each PubSubPool.UntrackConn call.
	Untracked uint32
	// Active is incremented by PubSubPool.TrackConn and decremented by
	// PubSubPool.UntrackConn.
	Active uint32
}
|
||||
|
||||
// PubSubPool manages a pool of PubSub connections.
type PubSubPool struct {
	// opt supplies the read and write buffer sizes used for new connections.
	opt *Options
	// netDialer opens the underlying network connection for NewConn.
	netDialer func(ctx context.Context, network, addr string) (net.Conn, error)

	// Map to track active PubSub connections
	activeConns sync.Map // map[uint64]*Conn (connID -> conn)
	// closed is set by Close; NewConn refuses to dial once it is true.
	closed atomic.Bool
	// stats holds the counters; they are updated and read with sync/atomic.
	stats PubSubStats
}
|
||||
|
||||
// NewPubSubPool implements a pool for PubSub connections.
|
||||
// It intentionally does not implement the Pooler interface
|
||||
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
|
||||
return &PubSubPool{
|
||||
opt: opt,
|
||||
netDialer: netDialer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
netConn, err := p.netDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
|
||||
cn.pubsub = true
|
||||
atomic.AddUint32(&p.stats.Created, 1)
|
||||
return cn, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *PubSubPool) TrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, 1)
|
||||
p.activeConns.Store(cn.GetID(), cn)
|
||||
}
|
||||
|
||||
func (p *PubSubPool) UntrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, ^uint32(0))
|
||||
atomic.AddUint32(&p.stats.Untracked, 1)
|
||||
p.activeConns.Delete(cn.GetID())
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Close() error {
|
||||
p.closed.Store(true)
|
||||
p.activeConns.Range(func(key, value interface{}) bool {
|
||||
cn := value.(*Conn)
|
||||
_ = cn.Close()
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Stats() *PubSubStats {
|
||||
// load stats atomically
|
||||
return &PubSubStats{
|
||||
Created: atomic.LoadUint32(&p.stats.Created),
|
||||
Untracked: atomic.LoadUint32(&p.stats.Untracked),
|
||||
Active: atomic.LoadUint32(&p.stats.Active),
|
||||
}
|
||||
}
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// wantConn represents one pending request for a connection. The outcome (a
// connection or an error) is delivered through result.
//
// NOTE(review): tryDeliver sends on result while holding mu, so result
// presumably has capacity for at least one value — confirm where wantConn is
// constructed.
type wantConn struct {
	mu        sync.Mutex      // protects ctx, done and sending of the result
	ctx       context.Context // context for dial, cleared after delivered or canceled
	cancelCtx context.CancelFunc
	done      bool                // true after delivered or canceled
	result    chan wantConnResult // channel to deliver connection or error
}
|
||||
|
||||
// getCtxForDial returns context for dial or nil if connection was delivered or canceled.
|
||||
func (w *wantConn) getCtxForDial() context.Context {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
func (w *wantConn) tryDeliver(cn *Conn, err error) bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.done {
|
||||
return false
|
||||
}
|
||||
|
||||
w.done = true
|
||||
w.ctx = nil
|
||||
|
||||
w.result <- wantConnResult{cn: cn, err: err}
|
||||
close(w.result)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *wantConn) cancel() *Conn {
|
||||
w.mu.Lock()
|
||||
var cn *Conn
|
||||
if w.done {
|
||||
select {
|
||||
case result := <-w.result:
|
||||
cn = result.cn
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
close(w.result)
|
||||
}
|
||||
|
||||
w.done = true
|
||||
w.ctx = nil
|
||||
w.mu.Unlock()
|
||||
|
||||
return cn
|
||||
}
|
||||
|
||||
// wantConnResult is the value sent on wantConn.result: the connection and the
// error passed to tryDeliver.
type wantConnResult struct {
	cn  *Conn
	err error
}
|
||||
|
||||
// wantConnQueue is a mutex-guarded first-in, first-out queue of *wantConn:
// enqueue appends at the tail and dequeue removes from the head.
type wantConnQueue struct {
	// mu guards items.
	mu sync.RWMutex
	// items holds the queued requests, oldest first.
	items []*wantConn
}
|
||||
|
||||
func newWantConnQueue() *wantConnQueue {
|
||||
return &wantConnQueue{
|
||||
items: make([]*wantConn, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *wantConnQueue) enqueue(w *wantConn) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.items = append(q.items, w)
|
||||
}
|
||||
|
||||
func (q *wantConnQueue) dequeue() (*wantConn, bool) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
|
||||
if len(q.items) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item := q.items[0]
|
||||
q.items = q.items[1:]
|
||||
return item, true
|
||||
}
|
||||
+89
-2
@@ -50,7 +50,8 @@ func (e RedisError) Error() string { return string(e) }
|
||||
func (RedisError) RedisError() {}
|
||||
|
||||
func ParseErrorReply(line []byte) error {
|
||||
return RedisError(line[1:])
|
||||
msg := string(line[1:])
|
||||
return parseTypedRedisError(msg)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -99,6 +100,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
func (r *Reader) PeekPushNotificationName() (string, error) {
|
||||
// "prime" the buffer by peeking at the next byte
|
||||
c, err := r.Peek(1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if c[0] != RespPush {
|
||||
return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
|
||||
}
|
||||
|
||||
// peek 36 bytes at most, should be enough to read the push notification name
|
||||
toPeek := 36
|
||||
buffered := r.Buffered()
|
||||
if buffered == 0 {
|
||||
return "", fmt.Errorf("redis: can't peek push notification name, no data available")
|
||||
}
|
||||
if buffered < toPeek {
|
||||
toPeek = buffered
|
||||
}
|
||||
buf, err := r.rd.Peek(toPeek)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if buf[0] != RespPush {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
|
||||
if len(buf) < 3 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
|
||||
// remove push notification type
|
||||
buf = buf[1:]
|
||||
// remove first line - e.g. >2\r\n
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
// next line should be $<length><string>\r\n or +<length><string>\r\n
|
||||
// should have the type of the push notification name and it's length
|
||||
if buf[0] != RespString && buf[0] != RespStatus {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
typeOfName := buf[0]
|
||||
// remove the type of the push notification name
|
||||
buf = buf[1:]
|
||||
if typeOfName == RespString {
|
||||
// remove the length of the string
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
// keep only the notification name
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return util.BytesToString(buf), nil
|
||||
}
|
||||
|
||||
// ReadLine Return a valid reply, it will check the protocol or redis error,
|
||||
// and discard the attribute type.
|
||||
func (r *Reader) ReadLine() ([]byte, error) {
|
||||
@@ -115,7 +202,7 @@ func (r *Reader) ReadLine() ([]byte, error) {
|
||||
var blobErr string
|
||||
blobErr, err = r.readStringReply(line)
|
||||
if err == nil {
|
||||
err = RedisError(blobErr)
|
||||
err = parseTypedRedisError(blobErr)
|
||||
}
|
||||
return nil, err
|
||||
case RespAttr:
|
||||
|
||||
+488
@@ -0,0 +1,488 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Typed Redis errors for better error handling with wrapping support.
|
||||
// These errors maintain backward compatibility by keeping the same error messages.
|
||||
|
||||
// LoadingError is returned when Redis is loading the dataset in memory.
type LoadingError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *LoadingError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *LoadingError) RedisError() {}

// NewLoadingError returns a LoadingError carrying msg.
func NewLoadingError(msg string) *LoadingError {
	e := new(LoadingError)
	e.msg = msg
	return e
}
|
||||
|
||||
// ReadOnlyError is returned when trying to write to a read-only replica.
type ReadOnlyError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *ReadOnlyError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *ReadOnlyError) RedisError() {}

// NewReadOnlyError returns a ReadOnlyError carrying msg.
func NewReadOnlyError(msg string) *ReadOnlyError {
	e := new(ReadOnlyError)
	e.msg = msg
	return e
}
|
||||
|
||||
// MovedError is returned when a key has been moved to a different node in a cluster.
type MovedError struct {
	msg  string // full error message
	addr string // address of the node the key moved to
}

// Error returns the message the error was created with, unchanged.
func (e *MovedError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *MovedError) RedisError() {}

// Addr returns the address of the node where the key has been moved.
func (e *MovedError) Addr() string { return e.addr }

// NewMovedError returns a MovedError carrying msg and addr.
func NewMovedError(msg string, addr string) *MovedError {
	e := new(MovedError)
	e.msg, e.addr = msg, addr
	return e
}
|
||||
|
||||
// AskError is returned when a key is being migrated and the client should ask another node.
type AskError struct {
	msg  string // full error message
	addr string // address of the node to ask
}

// Error returns the message the error was created with, unchanged.
func (e *AskError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *AskError) RedisError() {}

// Addr returns the address of the node to ask.
func (e *AskError) Addr() string { return e.addr }

// NewAskError returns an AskError carrying msg and addr.
func NewAskError(msg string, addr string) *AskError {
	e := new(AskError)
	e.msg, e.addr = msg, addr
	return e
}
|
||||
|
||||
// ClusterDownError is returned when the cluster is down.
type ClusterDownError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *ClusterDownError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *ClusterDownError) RedisError() {}

// NewClusterDownError returns a ClusterDownError carrying msg.
func NewClusterDownError(msg string) *ClusterDownError {
	e := new(ClusterDownError)
	e.msg = msg
	return e
}
|
||||
|
||||
// TryAgainError is returned when a command cannot be processed and should be retried.
type TryAgainError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *TryAgainError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *TryAgainError) RedisError() {}

// NewTryAgainError returns a TryAgainError carrying msg.
func NewTryAgainError(msg string) *TryAgainError {
	e := new(TryAgainError)
	e.msg = msg
	return e
}
|
||||
|
||||
// MasterDownError is returned when the master is down.
type MasterDownError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *MasterDownError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *MasterDownError) RedisError() {}

// NewMasterDownError returns a MasterDownError carrying msg.
func NewMasterDownError(msg string) *MasterDownError {
	e := new(MasterDownError)
	e.msg = msg
	return e
}
|
||||
|
||||
// MaxClientsError is returned when the maximum number of clients has been reached.
type MaxClientsError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *MaxClientsError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *MaxClientsError) RedisError() {}

// NewMaxClientsError returns a MaxClientsError carrying msg.
func NewMaxClientsError(msg string) *MaxClientsError {
	e := new(MaxClientsError)
	e.msg = msg
	return e
}
|
||||
|
||||
// AuthError is returned when authentication fails.
type AuthError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *AuthError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *AuthError) RedisError() {}

// NewAuthError returns an AuthError carrying msg.
func NewAuthError(msg string) *AuthError {
	e := new(AuthError)
	e.msg = msg
	return e
}
|
||||
|
||||
// PermissionError is returned when a user lacks required permissions.
type PermissionError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *PermissionError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *PermissionError) RedisError() {}

// NewPermissionError returns a PermissionError carrying msg.
func NewPermissionError(msg string) *PermissionError {
	e := new(PermissionError)
	e.msg = msg
	return e
}
|
||||
|
||||
// ExecAbortError is returned when a transaction is aborted.
type ExecAbortError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *ExecAbortError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *ExecAbortError) RedisError() {}

// NewExecAbortError returns an ExecAbortError carrying msg.
func NewExecAbortError(msg string) *ExecAbortError {
	e := new(ExecAbortError)
	e.msg = msg
	return e
}
|
||||
|
||||
// OOMError is returned when Redis is out of memory.
type OOMError struct{ msg string }

// Error returns the message the error was created with, unchanged.
func (e *OOMError) Error() string { return e.msg }

// RedisError is an empty marker method.
func (e *OOMError) RedisError() {}

// NewOOMError returns an OOMError carrying msg.
func NewOOMError(msg string) *OOMError {
	e := new(OOMError)
	e.msg = msg
	return e
}
|
||||
|
||||
// parseTypedRedisError parses a Redis error message and returns a typed error if applicable.
|
||||
// This function maintains backward compatibility by keeping the same error messages.
|
||||
func parseTypedRedisError(msg string) error {
|
||||
// Check for specific error patterns and return typed errors
|
||||
switch {
|
||||
case strings.HasPrefix(msg, "LOADING "):
|
||||
return NewLoadingError(msg)
|
||||
case strings.HasPrefix(msg, "READONLY "):
|
||||
return NewReadOnlyError(msg)
|
||||
case strings.HasPrefix(msg, "MOVED "):
|
||||
// Extract address from "MOVED <slot> <addr>"
|
||||
addr := extractAddr(msg)
|
||||
return NewMovedError(msg, addr)
|
||||
case strings.HasPrefix(msg, "ASK "):
|
||||
// Extract address from "ASK <slot> <addr>"
|
||||
addr := extractAddr(msg)
|
||||
return NewAskError(msg, addr)
|
||||
case strings.HasPrefix(msg, "CLUSTERDOWN "):
|
||||
return NewClusterDownError(msg)
|
||||
case strings.HasPrefix(msg, "TRYAGAIN "):
|
||||
return NewTryAgainError(msg)
|
||||
case strings.HasPrefix(msg, "MASTERDOWN "):
|
||||
return NewMasterDownError(msg)
|
||||
case msg == "ERR max number of clients reached":
|
||||
return NewMaxClientsError(msg)
|
||||
case strings.HasPrefix(msg, "NOAUTH "), strings.HasPrefix(msg, "WRONGPASS "), strings.Contains(msg, "unauthenticated"):
|
||||
return NewAuthError(msg)
|
||||
case strings.HasPrefix(msg, "NOPERM "):
|
||||
return NewPermissionError(msg)
|
||||
case strings.HasPrefix(msg, "EXECABORT "):
|
||||
return NewExecAbortError(msg)
|
||||
case strings.HasPrefix(msg, "OOM "):
|
||||
return NewOOMError(msg)
|
||||
default:
|
||||
// Return generic RedisError for unknown error types
|
||||
return RedisError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// extractAddr returns everything after the last space in msg, or "" when msg
// contains no space. For messages of the form "MOVED <slot> <addr>" or
// "ASK <slot> <addr>" that is the address.
func extractAddr(msg string) string {
	if sep := strings.LastIndex(msg, " "); sep >= 0 {
		return msg[sep+1:]
	}
	return ""
}
|
||||
|
||||
// IsLoadingError checks if an error is a LoadingError, even if wrapped.
|
||||
func IsLoadingError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var loadingErr *LoadingError
|
||||
if errors.As(err, &loadingErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with LOADING prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "LOADING ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "LOADING ")
|
||||
}
|
||||
|
||||
// IsReadOnlyError checks if an error is a ReadOnlyError, even if wrapped.
|
||||
func IsReadOnlyError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var readOnlyErr *ReadOnlyError
|
||||
if errors.As(err, &readOnlyErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with READONLY prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "READONLY ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
}
|
||||
|
||||
// IsMovedError checks if an error is a MovedError, even if wrapped.
|
||||
// Returns the error and a boolean indicating if it's a MovedError.
|
||||
func IsMovedError(err error) (*MovedError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var movedErr *MovedError
|
||||
if errors.As(err, &movedErr) {
|
||||
return movedErr, true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
if strings.HasPrefix(s, "MOVED ") {
|
||||
// Parse: MOVED 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
return &MovedError{msg: s, addr: parts[2]}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsAskError checks if an error is an AskError, even if wrapped.
|
||||
// Returns the error and a boolean indicating if it's an AskError.
|
||||
func IsAskError(err error) (*AskError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var askErr *AskError
|
||||
if errors.As(err, &askErr) {
|
||||
return askErr, true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
if strings.HasPrefix(s, "ASK ") {
|
||||
// Parse: ASK 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
return &AskError{msg: s, addr: parts[2]}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsClusterDownError checks if an error is a ClusterDownError, even if wrapped.
|
||||
func IsClusterDownError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var clusterDownErr *ClusterDownError
|
||||
if errors.As(err, &clusterDownErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with CLUSTERDOWN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "CLUSTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "CLUSTERDOWN ")
|
||||
}
|
||||
|
||||
// IsTryAgainError checks if an error is a TryAgainError, even if wrapped.
|
||||
func IsTryAgainError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var tryAgainErr *TryAgainError
|
||||
if errors.As(err, &tryAgainErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with TRYAGAIN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "TRYAGAIN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "TRYAGAIN ")
|
||||
}
|
||||
|
||||
// IsMasterDownError checks if an error is a MasterDownError, even if wrapped.
|
||||
func IsMasterDownError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var masterDownErr *MasterDownError
|
||||
if errors.As(err, &masterDownErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with MASTERDOWN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "MASTERDOWN ")
|
||||
}
|
||||
|
||||
// IsMaxClientsError checks if an error is a MaxClientsError, even if wrapped.
|
||||
func IsMaxClientsError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var maxClientsErr *MaxClientsError
|
||||
if errors.As(err, &maxClientsErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with max clients prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "ERR max number of clients reached") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "ERR max number of clients reached")
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is an AuthError, even if wrapped.
|
||||
func IsAuthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var authErr *AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with auth error prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) {
|
||||
s := redisErr.Error()
|
||||
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
|
||||
}
|
||||
|
||||
// IsPermissionError checks if an error is a PermissionError, even if wrapped.
|
||||
func IsPermissionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var permErr *PermissionError
|
||||
if errors.As(err, &permErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with NOPERM prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "NOPERM ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "NOPERM ")
|
||||
}
|
||||
|
||||
// IsExecAbortError checks if an error is an ExecAbortError, even if wrapped.
|
||||
func IsExecAbortError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var execAbortErr *ExecAbortError
|
||||
if errors.As(err, &execAbortErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with EXECABORT prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "EXECABORT ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "EXECABORT ")
|
||||
}
|
||||
|
||||
// IsOOMError checks if an error is an OOMError, even if wrapped.
|
||||
func IsOOMError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var oomErr *OOMError
|
||||
if errors.As(err, &oomErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with OOM prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "OOM ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "OOM ")
|
||||
}
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
package internal
|
||||
|
||||
// RedisNull is the literal string "<nil>".
// NOTE(review): presumably the textual form used for a nil Redis reply —
// confirm against the callers of this constant.
const RedisNull = "<nil>"
|
||||
+193
@@ -0,0 +1,193 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var semTimers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
t := time.NewTimer(time.Hour)
|
||||
t.Stop()
|
||||
return t
|
||||
},
|
||||
}
|
||||
|
||||
// FastSemaphore is a channel-based semaphore optimized for performance.
|
||||
// It uses a fast path that avoids timer allocation when tokens are available.
|
||||
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
|
||||
// Closing the semaphore unblocks all waiting goroutines.
|
||||
//
|
||||
// Performance: ~30 ns/op with zero allocations on fast path.
|
||||
// Fairness: Eventual fairness (no starvation) but not strict FIFO.
|
||||
type FastSemaphore struct {
|
||||
tokens chan struct{}
|
||||
max int32
|
||||
}
|
||||
|
||||
// NewFastSemaphore creates a new fast semaphore with the given capacity.
|
||||
func NewFastSemaphore(capacity int32) *FastSemaphore {
|
||||
ch := make(chan struct{}, capacity)
|
||||
// Pre-fill with tokens
|
||||
for i := int32(0); i < capacity; i++ {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
return &FastSemaphore{
|
||||
tokens: ch,
|
||||
max: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a token without blocking.
|
||||
// Returns true if successful, false if no tokens available.
|
||||
func (s *FastSemaphore) TryAcquire() bool {
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire acquires a token, blocking if necessary until one is available.
|
||||
// Returns an error if the context is cancelled or the timeout expires.
|
||||
// Uses a fast path to avoid timer allocation when tokens are immediately available.
|
||||
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
|
||||
// Check context first
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Try fast path first (no timer needed)
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Slow path: need to wait with timeout
|
||||
timer := semTimers.Get().(*time.Timer)
|
||||
defer semTimers.Put(timer)
|
||||
timer.Reset(timeout)
|
||||
|
||||
select {
|
||||
case <-s.tokens:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return timeoutErr
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
|
||||
func (s *FastSemaphore) AcquireBlocking() {
|
||||
<-s.tokens
|
||||
}
|
||||
|
||||
// Release releases a token back to the semaphore.
|
||||
func (s *FastSemaphore) Release() {
|
||||
s.tokens <- struct{}{}
|
||||
}
|
||||
|
||||
// Close closes the semaphore, unblocking all waiting goroutines.
|
||||
// After close, all Acquire calls will receive a closed channel signal.
|
||||
func (s *FastSemaphore) Close() {
|
||||
close(s.tokens)
|
||||
}
|
||||
|
||||
// Len returns the current number of acquired tokens.
|
||||
func (s *FastSemaphore) Len() int32 {
|
||||
return s.max - int32(len(s.tokens))
|
||||
}
|
||||
|
||||
// FIFOSemaphore is a channel-based semaphore with strict FIFO ordering.
|
||||
// Unlike FastSemaphore, this guarantees that threads are served in the exact order they call Acquire().
|
||||
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
|
||||
// Closing the semaphore unblocks all waiting goroutines.
|
||||
//
|
||||
// Performance: ~115 ns/op with zero allocations (slower than FastSemaphore due to timer allocation).
|
||||
// Fairness: Strict FIFO ordering guaranteed by Go runtime.
|
||||
type FIFOSemaphore struct {
|
||||
tokens chan struct{}
|
||||
max int32
|
||||
}
|
||||
|
||||
// NewFIFOSemaphore creates a new FIFO semaphore with the given capacity.
|
||||
func NewFIFOSemaphore(capacity int32) *FIFOSemaphore {
|
||||
ch := make(chan struct{}, capacity)
|
||||
// Pre-fill with tokens
|
||||
for i := int32(0); i < capacity; i++ {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
return &FIFOSemaphore{
|
||||
tokens: ch,
|
||||
max: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a token without blocking.
|
||||
// Returns true if successful, false if no tokens available.
|
||||
func (s *FIFOSemaphore) TryAcquire() bool {
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire acquires a token, blocking if necessary until one is available.
|
||||
// Returns an error if the context is cancelled or the timeout expires.
|
||||
// Always uses timer to guarantee FIFO ordering (no fast path).
|
||||
func (s *FIFOSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
|
||||
// No fast path - always use timer to guarantee FIFO
|
||||
timer := semTimers.Get().(*time.Timer)
|
||||
defer semTimers.Put(timer)
|
||||
timer.Reset(timeout)
|
||||
|
||||
select {
|
||||
case <-s.tokens:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return timeoutErr
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
|
||||
func (s *FIFOSemaphore) AcquireBlocking() {
|
||||
<-s.tokens
|
||||
}
|
||||
|
||||
// Release releases a token back to the semaphore.
|
||||
func (s *FIFOSemaphore) Release() {
|
||||
s.tokens <- struct{}{}
|
||||
}
|
||||
|
||||
// Close closes the semaphore, unblocking all waiting goroutines.
|
||||
// After close, all Acquire calls will receive a closed channel signal.
|
||||
func (s *FIFOSemaphore) Close() {
|
||||
close(s.tokens)
|
||||
}
|
||||
|
||||
// Len returns the current number of acquired tokens.
|
||||
func (s *FIFOSemaphore) Len() int32 {
|
||||
return s.max - int32(len(s.tokens))
|
||||
}
|
||||
+11
@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// SafeIntToInt32 converts value to int32 without silent truncation.
//
// If value lies outside the int32 range, it returns 0 and an error that names
// fieldName, the offending value and the bound that was violated.
func SafeIntToInt32(value int, fieldName string) (int32, error) {
	switch {
	case value > math.MaxInt32:
		return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
	case value < math.MinInt32:
		return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
	default:
		return int32(value), nil
	}
}
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
// Max returns the larger of a and b.
func Max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
||||
|
||||
// Min returns the smaller of a and b.
func Min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
||||
+218
@@ -0,0 +1,218 @@
|
||||
# Maintenance Notifications - FEATURES
|
||||
|
||||
## Overview
|
||||
|
||||
The Maintenance Notifications feature enables seamless Redis connection handoffs during cluster maintenance operations without dropping active connections. This feature leverages Redis RESP3 push notifications to provide zero-downtime maintenance for Redis Enterprise and compatible Redis deployments.
|
||||
|
||||
## Important
|
||||
|
||||
Using Maintenance Notifications may affect the read and write timeouts by relaxing them during maintenance operations.
|
||||
This is necessary to prevent false failures due to increased latency during handoffs. The relaxed timeouts are automatically applied and removed as needed.
|
||||
|
||||
## Key Features
|
||||
|
||||
### Seamless Connection Handoffs
|
||||
- **Zero-Downtime Maintenance**: Automatically handles connection transitions during cluster operations
|
||||
- **Active Operation Preservation**: Transfers in-flight operations to new connections without interruption
|
||||
- **Graceful Degradation**: Falls back to standard reconnection if handoff fails
|
||||
|
||||
### Push Notification Support
|
||||
Supports all Redis Enterprise maintenance notification types:
|
||||
- **MOVING** - Slot moving to a new node
|
||||
- **MIGRATING** - Slot in migration state
|
||||
- **MIGRATED** - Migration completed
|
||||
- **FAILING_OVER** - Node failing over
|
||||
- **FAILED_OVER** - Failover completed
|
||||
|
||||
### Circuit Breaker Pattern
|
||||
- **Endpoint-Specific Failure Tracking**: Prevents repeated connection attempts to failing endpoints
|
||||
- **Automatic Recovery Testing**: Half-open state allows gradual recovery validation
|
||||
- **Configurable Thresholds**: Customize failure thresholds and reset timeouts
|
||||
|
||||
### Flexible Configuration
|
||||
- **Auto-Detection Mode**: Automatically detects server support for maintenance notifications
|
||||
- **Multiple Endpoint Types**: Support for internal/external IP/FQDN endpoint resolution
|
||||
- **Auto-Scaling Workers**: Automatically sizes worker pool based on connection pool size
|
||||
- **Timeout Management**: Separate timeouts for relaxed (during maintenance) and normal operations
|
||||
|
||||
### Extensible Hook System
|
||||
- **Pre/Post Processing Hooks**: Monitor and customize notification handling
|
||||
- **Built-in Hooks**: Logging and metrics collection hooks included
|
||||
- **Custom Hook Support**: Implement custom business logic around maintenance events
|
||||
|
||||
### Comprehensive Monitoring
|
||||
- **Metrics Collection**: Track notification counts, processing times, and error rates
|
||||
- **Circuit Breaker Stats**: Monitor endpoint health and circuit breaker states
|
||||
- **Operation Tracking**: Track active handoff operations and their lifecycle
|
||||
|
||||
## Architecture Highlights
|
||||
|
||||
### Event-Driven Handoff System
|
||||
- **Asynchronous Processing**: Non-blocking handoff operations using worker pool pattern
|
||||
- **Queue-Based Architecture**: Configurable queue size with auto-scaling support
|
||||
- **Retry Mechanism**: Configurable retry attempts with exponential backoff
|
||||
|
||||
### Connection Pool Integration
|
||||
- **Pool Hook Interface**: Seamless integration with go-redis connection pool
|
||||
- **Connection State Management**: Atomic flags for connection usability tracking
|
||||
- **Graceful Shutdown**: Ensures all in-flight handoffs complete before shutdown
|
||||
|
||||
### Thread-Safe Design
|
||||
- **Lock-Free Operations**: Atomic operations for high-performance state tracking
|
||||
- **Concurrent-Safe Maps**: sync.Map for tracking active operations
|
||||
- **Minimal Lock Contention**: Read-write locks only where necessary
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Operation Modes
|
||||
- **`ModeDisabled`**: Maintenance notifications completely disabled
|
||||
- **`ModeEnabled`**: Forcefully enabled (fails if server doesn't support)
|
||||
- **`ModeAuto`**: Auto-detect server support (recommended default)
|
||||
|
||||
### Endpoint Types
|
||||
- **`EndpointTypeAuto`**: Auto-detect based on current connection
|
||||
- **`EndpointTypeInternalIP`**: Use internal IP addresses
|
||||
- **`EndpointTypeInternalFQDN`**: Use internal fully qualified domain names
|
||||
- **`EndpointTypeExternalIP`**: Use external IP addresses
|
||||
- **`EndpointTypeExternalFQDN`**: Use external fully qualified domain names
|
||||
- **`EndpointTypeNone`**: No endpoint (reconnect with current configuration)
|
||||
|
||||
### Timeout Configuration
|
||||
- **`RelaxedTimeout`**: Extended timeout during maintenance operations (default: 10s)
|
||||
- **`HandoffTimeout`**: Maximum time for handoff completion (default: 15s)
|
||||
- **`PostHandoffRelaxedDuration`**: Relaxed period after handoff (default: 2×RelaxedTimeout)
|
||||
|
||||
### Worker Pool Configuration
|
||||
- **`MaxWorkers`**: Maximum concurrent handoff workers (auto-calculated if 0)
|
||||
- **`HandoffQueueSize`**: Handoff queue capacity (auto-calculated if 0)
|
||||
- **`MaxHandoffRetries`**: Maximum retry attempts for failed handoffs (default: 3)
|
||||
|
||||
### Circuit Breaker Configuration
|
||||
- **`CircuitBreakerFailureThreshold`**: Failures before opening circuit (default: 5)
|
||||
- **`CircuitBreakerResetTimeout`**: Time before testing recovery (default: 60s)
|
||||
- **`CircuitBreakerMaxRequests`**: Max requests in half-open state (default: 3)
|
||||
|
||||
## Auto-Scaling Formulas
|
||||
|
||||
### Worker Pool Sizing
|
||||
When `MaxWorkers = 0` (auto-calculate):
|
||||
```
|
||||
MaxWorkers = min(PoolSize/2, max(10, PoolSize/3))
|
||||
```
|
||||
|
||||
### Queue Sizing
|
||||
When `HandoffQueueSize = 0` (auto-calculate):
|
||||
```
|
||||
QueueSize = max(20 × MaxWorkers, PoolSize)
|
||||
Capped by: min(MaxActiveConns + 1, 5 × PoolSize)
|
||||
```
|
||||
|
||||
### Examples
|
||||
- **Pool Size 100**: 33 workers, 660 queue (capped at 500)
|
||||
- **Pool Size 100 + MaxActiveConns 150**: 33 workers, 151 queue
|
||||
- **Pool Size 50**: 16 workers, 320 queue (capped at 250)
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
### Throughput
|
||||
- **Non-Blocking Handoffs**: Client operations continue during handoffs
|
||||
- **Concurrent Processing**: Multiple handoffs processed in parallel
|
||||
- **Minimal Overhead**: Lock-free atomic operations for state tracking
|
||||
|
||||
### Latency
|
||||
- **Relaxed Timeouts**: Extended timeouts during maintenance prevent false failures
|
||||
- **Fast Path**: Connections not undergoing handoff have zero overhead
|
||||
- **Graceful Degradation**: Failed handoffs fall back to standard reconnection
|
||||
|
||||
### Resource Usage
|
||||
- **Memory Efficient**: Bounded queue sizes prevent memory exhaustion
|
||||
- **Worker Pool**: Fixed worker count prevents goroutine explosion
|
||||
- **Connection Reuse**: Handoff reuses existing connection objects
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Tests
|
||||
- Comprehensive unit test coverage for all components
|
||||
- Mock-based testing for isolation
|
||||
- Concurrent operation testing
|
||||
|
||||
### Integration Tests
|
||||
- Pool integration tests with real connection handoffs
|
||||
- Circuit breaker behavior validation
|
||||
- Hook system integration testing
|
||||
|
||||
### E2E Tests
|
||||
- Real Redis Enterprise cluster testing
|
||||
- Multiple scenario coverage (timeouts, endpoint types, stress tests)
|
||||
- Fault injection testing
|
||||
- TLS configuration testing
|
||||
|
||||
## Compatibility
|
||||
|
||||
### Requirements
|
||||
- **Redis Protocol**: RESP3 required for push notifications
|
||||
- **Redis Version**: Redis Enterprise or compatible Redis with maintenance notifications
|
||||
- **Go Version**: Go 1.19+ (uses generics and the typed atomics from `sync/atomic`, such as `atomic.Int64`, which were added in Go 1.19)
|
||||
|
||||
### Client Support
|
||||
#### Currently Supported
|
||||
- **Standalone Client** (`redis.NewClient`)
|
||||
|
||||
#### Planned Support
|
||||
- **Cluster Client** (not yet supported)
|
||||
|
||||
#### Will Not Support
|
||||
- **Failover Client** (no planned support)
|
||||
- **Ring Client** (no planned support)
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Enabling Maintenance Notifications
|
||||
|
||||
**Before:**
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 2, // RESP2
|
||||
})
|
||||
```
|
||||
|
||||
**After:**
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 3, // RESP3 required
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeAuto,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Adding Monitoring
|
||||
|
||||
```go
|
||||
// Get the manager from the client
|
||||
manager := client.GetMaintNotificationsManager()
|
||||
if manager != nil {
|
||||
// Add logging hook
|
||||
loggingHook := maintnotifications.NewLoggingHook(2) // Info level
|
||||
manager.AddNotificationHook(loggingHook)
|
||||
|
||||
// Add metrics hook
|
||||
metricsHook := maintnotifications.NewMetricsHook()
|
||||
manager.AddNotificationHook(metricsHook)
|
||||
}
|
||||
```
|
||||
|
||||
## Known Limitations
|
||||
|
||||
1. **Standalone Only**: Currently only supported in standalone Redis clients
|
||||
2. **RESP3 Required**: Push notifications require RESP3 protocol
|
||||
3. **Server Support**: Requires Redis Enterprise or compatible Redis with maintenance notifications
|
||||
4. **Single Connection Commands**: Some commands (MULTI/EXEC, WATCH) may need special handling
|
||||
5. **No Failover/Ring Client Support**: Failover and Ring clients are not supported and there are no plans to add support
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Cluster client support
|
||||
- Enhanced metrics and observability
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
# Maintenance Notifications
|
||||
|
||||
Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.
|
||||
|
||||
## ⚠️ **Important Note**
|
||||
**Maintenance notifications are currently supported only in standalone Redis clients (`redis.NewClient`).** Other client types (ClusterClient, FailoverClient, Ring) do not support this functionality; see [FEATURES](FEATURES.md) for which of them are planned.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 3, // RESP3 required
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Modes
|
||||
|
||||
- **`ModeDisabled`** - Maintenance notifications disabled
|
||||
- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support)
|
||||
- **`ModeAuto`** - Auto-detect server support (default)
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
&maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeAuto,
|
||||
EndpointType: maintnotifications.EndpointTypeAuto,
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxHandoffRetries: 3,
|
||||
MaxWorkers: 0, // Auto-calculated
|
||||
HandoffQueueSize: 0, // Auto-calculated
|
||||
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
|
||||
}
|
||||
```
|
||||
|
||||
### Endpoint Types
|
||||
|
||||
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
|
||||
- **`EndpointTypeInternalIP`** - Internal IP address
|
||||
- **`EndpointTypeInternalFQDN`** - Internal FQDN
|
||||
- **`EndpointTypeExternalIP`** - External IP address
|
||||
- **`EndpointTypeExternalFQDN`** - External FQDN
|
||||
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
|
||||
|
||||
### Auto-Scaling
|
||||
|
||||
**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
|
||||
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`
|
||||
|
||||
**Examples:**
|
||||
- Pool 100: 33 workers, 660 queue (capped at 500)
|
||||
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
|
||||
|
||||
## How It Works
|
||||
|
||||
1. Redis sends push notifications about cluster maintenance operations
|
||||
2. Client creates new connections to updated endpoints
|
||||
3. Active operations transfer to new connections
|
||||
4. Old connections close gracefully
|
||||
|
||||
|
||||
## For more information, see [FEATURES](FEATURES.md)
|
||||
+353
@@ -0,0 +1,353 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// CircuitBreakerState enumerates the states a circuit breaker can be in.
type CircuitBreakerState int32

const (
	// CircuitBreakerClosed is normal operation: requests are allowed.
	CircuitBreakerClosed CircuitBreakerState = iota
	// CircuitBreakerOpen is the fail-fast state: requests are rejected.
	CircuitBreakerOpen
	// CircuitBreakerHalfOpen is the probing state used to test whether the
	// service has recovered.
	CircuitBreakerHalfOpen
)

// String returns the lower-case name of the state ("closed", "open" or
// "half-open"), or "unknown" for any value outside the declared constants.
func (s CircuitBreakerState) String() string {
	names := [...]string{
		CircuitBreakerClosed:   "closed",
		CircuitBreakerOpen:     "open",
		CircuitBreakerHalfOpen: "half-open",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
|
||||
type CircuitBreaker struct {
|
||||
// Configuration
|
||||
failureThreshold int // Number of failures before opening
|
||||
resetTimeout time.Duration // How long to stay open before testing
|
||||
maxRequests int // Max requests allowed in half-open state
|
||||
|
||||
// State tracking (atomic for lock-free access)
|
||||
state atomic.Int32 // CircuitBreakerState
|
||||
failures atomic.Int64 // Current failure count
|
||||
successes atomic.Int64 // Success count in half-open state
|
||||
requests atomic.Int64 // Request count in half-open state
|
||||
lastFailureTime atomic.Int64 // Unix timestamp of last failure
|
||||
lastSuccessTime atomic.Int64 // Unix timestamp of last success
|
||||
|
||||
// Endpoint identification
|
||||
endpoint string
|
||||
config *Config
|
||||
}
|
||||
|
||||
// newCircuitBreaker creates a new circuit breaker for an endpoint
|
||||
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
|
||||
// Use configuration values with sensible defaults
|
||||
failureThreshold := 5
|
||||
resetTimeout := 60 * time.Second
|
||||
maxRequests := 3
|
||||
|
||||
if config != nil {
|
||||
failureThreshold = config.CircuitBreakerFailureThreshold
|
||||
resetTimeout = config.CircuitBreakerResetTimeout
|
||||
maxRequests = config.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
failureThreshold: failureThreshold,
|
||||
resetTimeout: resetTimeout,
|
||||
maxRequests: maxRequests,
|
||||
endpoint: endpoint,
|
||||
config: config,
|
||||
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
|
||||
}
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is open (rejecting requests)
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
return state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// shouldAttemptReset checks if enough time has passed to attempt reset
|
||||
func (cb *CircuitBreaker) shouldAttemptReset() bool {
|
||||
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
|
||||
return time.Since(lastFailure) >= cb.resetTimeout
|
||||
}
|
||||
|
||||
// Execute runs the given function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
// Single atomic state load for consistency
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerOpen:
|
||||
if cb.shouldAttemptReset() {
|
||||
// Attempt transition to half-open
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
|
||||
cb.requests.Store(0)
|
||||
cb.successes.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
|
||||
}
|
||||
// Fall through to half-open logic
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
fallthrough
|
||||
case CircuitBreakerHalfOpen:
|
||||
requests := cb.requests.Add(1)
|
||||
if requests > int64(cb.maxRequests) {
|
||||
cb.requests.Add(-1) // Revert the increment
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the function with consistent state
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.lastFailureTime.Store(time.Now().Unix())
|
||||
failures := cb.failures.Add(1)
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
if failures >= int64(cb.failureThreshold) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
|
||||
}
|
||||
}
|
||||
}
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Any failure in half-open state immediately opens the circuit
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a success and potentially closes the circuit
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.lastSuccessTime.Store(time.Now().Unix())
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
// Reset failure count on success in closed state
|
||||
cb.failures.Store(0)
|
||||
case CircuitBreakerHalfOpen:
|
||||
successes := cb.successes.Add(1)
|
||||
|
||||
// If we've had enough successful requests, close the circuit
|
||||
if successes >= int64(cb.maxRequests) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
|
||||
cb.failures.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
return CircuitBreakerState(cb.state.Load())
|
||||
}
|
||||
|
||||
// GetStats returns current statistics for monitoring
|
||||
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
|
||||
return CircuitBreakerStats{
|
||||
Endpoint: cb.endpoint,
|
||||
State: cb.GetState(),
|
||||
Failures: cb.failures.Load(),
|
||||
Successes: cb.successes.Load(),
|
||||
Requests: cb.requests.Load(),
|
||||
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
|
||||
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats provides statistics about a circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
Endpoint string
|
||||
State CircuitBreakerState
|
||||
Failures int64
|
||||
Successes int64
|
||||
Requests int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerEntry wraps a circuit breaker with access tracking
|
||||
type CircuitBreakerEntry struct {
|
||||
breaker *CircuitBreaker
|
||||
lastAccess atomic.Int64 // Unix timestamp
|
||||
created time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerManager manages circuit breakers for multiple endpoints
|
||||
type CircuitBreakerManager struct {
|
||||
breakers sync.Map // map[string]*CircuitBreakerEntry
|
||||
config *Config
|
||||
cleanupStop chan struct{}
|
||||
cleanupMu sync.Mutex
|
||||
lastCleanup atomic.Int64 // Unix timestamp
|
||||
}
|
||||
|
||||
// newCircuitBreakerManager creates a new circuit breaker manager
|
||||
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
|
||||
cbm := &CircuitBreakerManager{
|
||||
config: config,
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
cbm.lastCleanup.Store(time.Now().Unix())
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go cbm.cleanupLoop()
|
||||
|
||||
return cbm
|
||||
}
|
||||
|
||||
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
|
||||
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
|
||||
now := time.Now().Unix()
|
||||
|
||||
if entry, ok := cbm.breakers.Load(endpoint); ok {
|
||||
cbEntry := entry.(*CircuitBreakerEntry)
|
||||
cbEntry.lastAccess.Store(now)
|
||||
return cbEntry.breaker
|
||||
}
|
||||
|
||||
// Create new circuit breaker with metadata
|
||||
newBreaker := newCircuitBreaker(endpoint, cbm.config)
|
||||
newEntry := &CircuitBreakerEntry{
|
||||
breaker: newBreaker,
|
||||
created: time.Now(),
|
||||
}
|
||||
newEntry.lastAccess.Store(now)
|
||||
|
||||
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
|
||||
return actual.(*CircuitBreakerEntry).breaker
|
||||
}
|
||||
|
||||
// GetAllStats returns statistics for all circuit breakers
|
||||
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
|
||||
var stats []CircuitBreakerStats
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
stats = append(stats, entry.breaker.GetStats())
|
||||
return true
|
||||
})
|
||||
return stats
|
||||
}
|
||||
|
||||
// cleanupLoop runs background cleanup of unused circuit breakers
|
||||
func (cbm *CircuitBreakerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cbm.cleanup()
|
||||
case <-cbm.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes circuit breakers that haven't been accessed recently
|
||||
func (cbm *CircuitBreakerManager) cleanup() {
|
||||
// Prevent concurrent cleanups
|
||||
if !cbm.cleanupMu.TryLock() {
|
||||
return
|
||||
}
|
||||
defer cbm.cleanupMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
|
||||
|
||||
var toDelete []string
|
||||
count := 0
|
||||
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
endpoint := key.(string)
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
|
||||
count++
|
||||
|
||||
// Remove if not accessed recently
|
||||
if entry.lastAccess.Load() < cutoff {
|
||||
toDelete = append(toDelete, endpoint)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete expired entries
|
||||
for _, endpoint := range toDelete {
|
||||
cbm.breakers.Delete(endpoint)
|
||||
}
|
||||
|
||||
// Log cleanup results
|
||||
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
|
||||
}
|
||||
|
||||
cbm.lastCleanup.Store(now.Unix())
|
||||
}
|
||||
|
||||
// Shutdown stops the cleanup goroutine
|
||||
func (cbm *CircuitBreakerManager) Shutdown() {
|
||||
close(cbm.cleanupStop)
|
||||
}
|
||||
|
||||
// Reset resets all circuit breakers (useful for testing)
|
||||
func (cbm *CircuitBreakerManager) Reset() {
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
breaker := entry.breaker
|
||||
breaker.state.Store(int32(CircuitBreakerClosed))
|
||||
breaker.failures.Store(0)
|
||||
breaker.successes.Store(0)
|
||||
breaker.requests.Store(0)
|
||||
breaker.lastFailureTime.Store(0)
|
||||
breaker.lastSuccessTime.Store(0)
|
||||
return true
|
||||
})
|
||||
}
|
||||
+458
@@ -0,0 +1,458 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
)
|
||||
|
||||
// Mode selects how the client handles maintenance notifications.
type Mode string

// Supported maintenance push notification modes.
const (
	// ModeDisabled: the client doesn't send the CLIENT MAINT_NOTIFICATIONS ON command.
	ModeDisabled Mode = "disabled"
	// ModeEnabled: the client forcefully sends the command and interrupts the connection on error.
	ModeEnabled Mode = "enabled"
	// ModeAuto: the client tries to send the command and disables the feature on error.
	ModeAuto Mode = "auto"
)

// IsValid reports whether m is one of the supported modes.
func (m Mode) IsValid() bool {
	return m == ModeDisabled || m == ModeEnabled || m == ModeAuto
}

// String returns the mode as a plain string.
func (m Mode) String() string { return string(m) }
|
||||
|
||||
// EndpointType names the kind of endpoint to request in MOVING notifications.
type EndpointType string

// Supported endpoint types.
const (
	// EndpointTypeAuto auto-detects the type based on the connection.
	EndpointTypeAuto EndpointType = "auto"
	// EndpointTypeInternalIP requests an internal IP address.
	EndpointTypeInternalIP EndpointType = "internal-ip"
	// EndpointTypeInternalFQDN requests an internal FQDN.
	EndpointTypeInternalFQDN EndpointType = "internal-fqdn"
	// EndpointTypeExternalIP requests an external IP address.
	EndpointTypeExternalIP EndpointType = "external-ip"
	// EndpointTypeExternalFQDN requests an external FQDN.
	EndpointTypeExternalFQDN EndpointType = "external-fqdn"
	// EndpointTypeNone requests no endpoint (reconnect with current config).
	EndpointTypeNone EndpointType = "none"
)

// IsValid reports whether e is one of the supported endpoint types.
func (e EndpointType) IsValid() bool {
	switch e {
	case EndpointTypeAuto,
		EndpointTypeInternalIP,
		EndpointTypeInternalFQDN,
		EndpointTypeExternalIP,
		EndpointTypeExternalFQDN,
		EndpointTypeNone:
		return true
	}
	return false
}

// String returns the endpoint type as a plain string.
func (e EndpointType) String() string { return string(e) }
|
||||
|
||||
// Config provides configuration options for maintenance notifications
|
||||
type Config struct {
|
||||
// Mode controls how client maintenance notifications are handled.
|
||||
// Valid values: ModeDisabled, ModeEnabled, ModeAuto
|
||||
// Default: ModeAuto
|
||||
Mode Mode
|
||||
|
||||
// EndpointType specifies the type of endpoint to request in MOVING notifications.
|
||||
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
|
||||
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
|
||||
// Default: EndpointTypeAuto
|
||||
EndpointType EndpointType
|
||||
|
||||
// RelaxedTimeout is the concrete timeout value to use during
|
||||
// MIGRATING/FAILING_OVER states to accommodate increased latency.
|
||||
// This applies to both read and write timeouts.
|
||||
// Default: 10 seconds
|
||||
RelaxedTimeout time.Duration
|
||||
|
||||
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
|
||||
// If handoff takes longer than this, the old connection will be forcibly closed.
|
||||
// Default: 15 seconds (matches server-side eviction timeout)
|
||||
HandoffTimeout time.Duration
|
||||
|
||||
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
|
||||
// Workers are created on-demand and automatically cleaned up when idle.
|
||||
// If zero, defaults to min(10, PoolSize/2) to handle bursts effectively.
|
||||
// If explicitly set, enforces minimum of PoolSize/2
|
||||
//
|
||||
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
|
||||
MaxWorkers int
|
||||
|
||||
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
|
||||
// If the queue is full, new handoff requests will be rejected.
|
||||
// Scales with both worker count and pool size for better burst handling.
|
||||
//
|
||||
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
HandoffQueueSize int
|
||||
|
||||
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
|
||||
// after a handoff completes. This provides additional resilience during cluster transitions.
|
||||
// Default: 2 * RelaxedTimeout
|
||||
PostHandoffRelaxedDuration time.Duration
|
||||
|
||||
// Circuit breaker configuration for endpoint failure handling
|
||||
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
|
||||
// Default: 5
|
||||
CircuitBreakerFailureThreshold int
|
||||
|
||||
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
|
||||
// Default: 60 seconds
|
||||
CircuitBreakerResetTimeout time.Duration
|
||||
|
||||
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
|
||||
// Default: 3
|
||||
CircuitBreakerMaxRequests int
|
||||
|
||||
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
|
||||
// After this many retries, the connection will be removed from the pool.
|
||||
// Default: 3
|
||||
MaxHandoffRetries int
|
||||
}
|
||||
|
||||
func (c *Config) IsEnabled() bool {
|
||||
return c != nil && c.Mode != ModeDisabled
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Mode: ModeAuto, // Enable by default for Redis Cloud
|
||||
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxWorkers: 0, // Auto-calculated based on pool size
|
||||
HandoffQueueSize: 0, // Auto-calculated based on max workers
|
||||
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: 5,
|
||||
CircuitBreakerResetTimeout: 60 * time.Second,
|
||||
CircuitBreakerMaxRequests: 3,
|
||||
|
||||
// Connection Handoff Configuration
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid.
|
||||
func (c *Config) Validate() error {
|
||||
if c.RelaxedTimeout <= 0 {
|
||||
return ErrInvalidRelaxedTimeout
|
||||
}
|
||||
if c.HandoffTimeout <= 0 {
|
||||
return ErrInvalidHandoffTimeout
|
||||
}
|
||||
// Validate worker configuration
|
||||
// Allow 0 for auto-calculation, but negative values are invalid
|
||||
if c.MaxWorkers < 0 {
|
||||
return ErrInvalidHandoffWorkers
|
||||
}
|
||||
// HandoffQueueSize validation - allow 0 for auto-calculation
|
||||
if c.HandoffQueueSize < 0 {
|
||||
return ErrInvalidHandoffQueueSize
|
||||
}
|
||||
if c.PostHandoffRelaxedDuration < 0 {
|
||||
return ErrInvalidPostHandoffRelaxedDuration
|
||||
}
|
||||
|
||||
// Circuit breaker validation
|
||||
if c.CircuitBreakerFailureThreshold < 1 {
|
||||
return ErrInvalidCircuitBreakerFailureThreshold
|
||||
}
|
||||
if c.CircuitBreakerResetTimeout < 0 {
|
||||
return ErrInvalidCircuitBreakerResetTimeout
|
||||
}
|
||||
if c.CircuitBreakerMaxRequests < 1 {
|
||||
return ErrInvalidCircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
// Validate Mode (maintenance notifications mode)
|
||||
if !c.Mode.IsValid() {
|
||||
return ErrInvalidMaintNotifications
|
||||
}
|
||||
|
||||
// Validate EndpointType
|
||||
if !c.EndpointType.IsValid() {
|
||||
return ErrInvalidEndpointType
|
||||
}
|
||||
|
||||
// Validate configuration fields
|
||||
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
|
||||
return ErrInvalidHandoffRetries
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDefaults applies default values to any zero-value fields in the configuration.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaults() *Config {
|
||||
return c.ApplyDefaultsWithPoolSize(0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size to calculate worker defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
|
||||
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size and max active connections to calculate worker and queue defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
defaults := DefaultConfig()
|
||||
result := &Config{}
|
||||
|
||||
// Apply defaults for enum fields (empty/zero means not set)
|
||||
result.Mode = defaults.Mode
|
||||
if c.Mode != "" {
|
||||
result.Mode = c.Mode
|
||||
}
|
||||
|
||||
result.EndpointType = defaults.EndpointType
|
||||
if c.EndpointType != "" {
|
||||
result.EndpointType = c.EndpointType
|
||||
}
|
||||
|
||||
// Apply defaults for duration fields (zero means not set)
|
||||
result.RelaxedTimeout = defaults.RelaxedTimeout
|
||||
if c.RelaxedTimeout > 0 {
|
||||
result.RelaxedTimeout = c.RelaxedTimeout
|
||||
}
|
||||
|
||||
result.HandoffTimeout = defaults.HandoffTimeout
|
||||
if c.HandoffTimeout > 0 {
|
||||
result.HandoffTimeout = c.HandoffTimeout
|
||||
}
|
||||
|
||||
// Copy worker configuration
|
||||
result.MaxWorkers = c.MaxWorkers
|
||||
|
||||
// Apply worker defaults based on pool size
|
||||
result.applyWorkerDefaults(poolSize)
|
||||
|
||||
// Apply queue size defaults with new scaling approach
|
||||
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolBasedSize := poolSize
|
||||
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
|
||||
if c.HandoffQueueSize > 0 {
|
||||
// When explicitly set: enforce minimum of 200
|
||||
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
|
||||
var queueCap int
|
||||
if maxActiveConns > 0 {
|
||||
queueCap = maxActiveConns + 1
|
||||
// Ensure queue cap is at least 2 for very small maxActiveConns
|
||||
if queueCap < 2 {
|
||||
queueCap = 2
|
||||
}
|
||||
} else {
|
||||
queueCap = poolSize * 5
|
||||
}
|
||||
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
|
||||
|
||||
// Ensure minimum queue size of 2 (fallback for very small pools)
|
||||
if result.HandoffQueueSize < 2 {
|
||||
result.HandoffQueueSize = 2
|
||||
}
|
||||
|
||||
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
|
||||
if c.PostHandoffRelaxedDuration > 0 {
|
||||
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
|
||||
}
|
||||
|
||||
// Apply defaults for configuration fields
|
||||
result.MaxHandoffRetries = defaults.MaxHandoffRetries
|
||||
if c.MaxHandoffRetries > 0 {
|
||||
result.MaxHandoffRetries = c.MaxHandoffRetries
|
||||
}
|
||||
|
||||
// Circuit breaker configuration
|
||||
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
|
||||
if c.CircuitBreakerFailureThreshold > 0 {
|
||||
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
|
||||
}
|
||||
|
||||
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
|
||||
if c.CircuitBreakerResetTimeout > 0 {
|
||||
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
|
||||
}
|
||||
|
||||
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
|
||||
if c.CircuitBreakerMaxRequests > 0 {
|
||||
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
|
||||
internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the configuration.
|
||||
func (c *Config) Clone() *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig()
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Mode: c.Mode,
|
||||
EndpointType: c.EndpointType,
|
||||
RelaxedTimeout: c.RelaxedTimeout,
|
||||
HandoffTimeout: c.HandoffTimeout,
|
||||
MaxWorkers: c.MaxWorkers,
|
||||
HandoffQueueSize: c.HandoffQueueSize,
|
||||
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
|
||||
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
|
||||
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
|
||||
|
||||
// Configuration fields
|
||||
MaxHandoffRetries: c.MaxHandoffRetries,
|
||||
}
|
||||
}
|
||||
|
||||
// applyWorkerDefaults calculates and applies worker defaults based on pool size
|
||||
func (c *Config) applyWorkerDefaults(poolSize int) {
|
||||
// Calculate defaults based on pool size
|
||||
if poolSize <= 0 {
|
||||
poolSize = 10 * runtime.GOMAXPROCS(0)
|
||||
}
|
||||
|
||||
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
|
||||
originalMaxWorkers := c.MaxWorkers
|
||||
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
|
||||
if originalMaxWorkers != 0 {
|
||||
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
|
||||
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
|
||||
}
|
||||
|
||||
// Ensure minimum of 1 worker (fallback for very small pools)
|
||||
if c.MaxWorkers < 1 {
|
||||
c.MaxWorkers = 1
|
||||
}
|
||||
}
|
||||
|
||||
// DetectEndpointType automatically detects the appropriate endpoint type
|
||||
// based on the connection address and TLS configuration.
|
||||
//
|
||||
// For IP addresses:
|
||||
// - If TLS is enabled: requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// For hostnames:
|
||||
// - If TLS is enabled: always requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// Internal vs External detection:
|
||||
// - For IPs: uses private IP range detection
|
||||
// - For hostnames: uses heuristics based on common internal naming patterns
|
||||
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
|
||||
// Extract host from "host:port" format
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr // Assume no port
|
||||
}
|
||||
|
||||
// Check if the host is an IP address or hostname
|
||||
ip := net.ParseIP(host)
|
||||
isIPAddress := ip != nil
|
||||
var endpointType EndpointType
|
||||
|
||||
if isIPAddress {
|
||||
// Address is an IP - determine if it's private or public
|
||||
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
|
||||
|
||||
if tlsEnabled {
|
||||
// TLS with IP addresses - still prefer FQDN for certificate validation
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
} else {
|
||||
// No TLS - can use IP addresses directly
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalIP
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalIP
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Address is a hostname
|
||||
isInternalHostname := isInternalHostname(host)
|
||||
if isInternalHostname {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
}
|
||||
|
||||
return endpointType
|
||||
}
|
||||
|
||||
// isInternalHostname reports whether hostname looks internal/private.
// This is a heuristic based on common naming patterns: the comparison is
// case-insensitive, any name without a dot (e.g. "redis-1", "db-server") counts
// as internal, and so does any name ending in one of the well-known internal
// suffixes below. Everything else is considered external.
func isInternalHostname(hostname string) bool {
	lowered := strings.ToLower(hostname)

	// Single-label names are very likely only resolvable inside a private network.
	if !strings.Contains(lowered, ".") {
		return true
	}

	// Common internal hostname endings. An exact match is also a suffix match,
	// so a single HasSuffix test covers both.
	internalSuffixes := [...]string{
		"localhost",
		".local",
		".internal",
		".corp",
		".lan",
		".intranet",
		".private",
	}
	for _, suffix := range internalSuffixes {
		if strings.HasSuffix(lowered, suffix) {
			return true
		}
	}

	// Default to external for fully qualified domain names.
	return false
}
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// Configuration errors
|
||||
var (
|
||||
ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError())
|
||||
ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError())
|
||||
ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError())
|
||||
ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError())
|
||||
ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
|
||||
ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError())
|
||||
ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError())
|
||||
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
|
||||
|
||||
// Configuration validation errors
|
||||
|
||||
// ErrInvalidHandoffRetries is returned when the number of handoff retries is invalid
|
||||
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
|
||||
)
|
||||
|
||||
// Integration errors
|
||||
var (
|
||||
// ErrInvalidClient is returned when the client does not support push notifications
|
||||
ErrInvalidClient = errors.New(logs.InvalidClientError())
|
||||
)
|
||||
|
||||
// Handoff errors
|
||||
var (
|
||||
// ErrHandoffQueueFull is returned when the handoff queue is full
|
||||
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
|
||||
)
|
||||
|
||||
// Notification errors
|
||||
var (
|
||||
// ErrInvalidNotification is returned when a notification is in an invalid format
|
||||
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
|
||||
)
|
||||
|
||||
// connection handoff errors
|
||||
var (
|
||||
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoff = errors.New(logs.ConnectionMarkedForHandoffErrorMessage)
|
||||
// ErrConnectionMarkedForHandoffWithState is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoffWithState = errors.New(logs.ConnectionMarkedForHandoffErrorMessage + " with state")
|
||||
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
|
||||
ErrConnectionInvalidHandoffState = errors.New(logs.ConnectionInvalidHandoffStateErrorMessage)
|
||||
)
|
||||
|
||||
// shutdown errors
|
||||
var (
|
||||
// ErrShutdown is returned when the maintnotifications manager is shutdown
|
||||
ErrShutdown = errors.New(logs.ShutdownError())
|
||||
)
|
||||
|
||||
// circuit breaker errors
|
||||
var (
|
||||
// ErrCircuitBreakerOpen is returned when the circuit breaker is open
|
||||
ErrCircuitBreakerOpen = errors.New(logs.CircuitBreakerOpenErrorMessage)
|
||||
)
|
||||
|
||||
// circuit breaker configuration errors
|
||||
var (
|
||||
// ErrInvalidCircuitBreakerFailureThreshold is returned when the circuit breaker failure threshold is invalid
|
||||
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
|
||||
// ErrInvalidCircuitBreakerResetTimeout is returned when the circuit breaker reset timeout is invalid
|
||||
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
|
||||
// ErrInvalidCircuitBreakerMaxRequests is returned when the circuit breaker max requests is invalid
|
||||
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
|
||||
)
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// contextKey is a private string type for context keys, so keys defined here
// cannot collide with keys defined by other packages.
type contextKey string

// startTimeKey is the context key under which PostHook looks up the time at
// which notification processing started.
const startTimeKey contextKey = "maint_notif_start_time"
|
||||
|
||||
// MetricsHook collects metrics about notification processing.
//
// NOTE(review): the maps are read and written by the hook methods without any
// locking visible in this file - confirm that hooks are never invoked
// concurrently, or guard the maps before sharing one hook across connections.
type MetricsHook struct {
	// NotificationCounts holds, per notification type, how many notifications PreHook saw.
	NotificationCounts map[string]int64
	// ProcessingTimes holds, per notification type, the most recent duration recorded by PostHook.
	ProcessingTimes map[string]time.Duration
	// ErrorCounts holds, per notification type, how many times PostHook received a non-nil result.
	ErrorCounts map[string]int64

	HandoffCounts    int64 // Total handoffs initiated
	HandoffSuccesses int64 // Successful handoffs
	HandoffFailures  int64 // Failed handoffs
}

// NewMetricsHook creates a metrics collection hook whose maps are allocated
// and empty and whose handoff counters are zero.
func NewMetricsHook() *MetricsHook {
	hook := new(MetricsHook)
	hook.NotificationCounts = map[string]int64{}
	hook.ProcessingTimes = map[string]time.Duration{}
	hook.ErrorCounts = map[string]int64{}
	return hook
}
|
||||
|
||||
// PreHook records the start time for processing metrics.
|
||||
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
mh.NotificationCounts[notificationType]++
|
||||
|
||||
// Log connection information if available
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
|
||||
}
|
||||
|
||||
// Store start time in context for duration calculation
|
||||
startTime := time.Now()
|
||||
_ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further
|
||||
|
||||
return notification, true
|
||||
}
|
||||
|
||||
// PostHook records processing completion and any errors.
|
||||
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
// Calculate processing duration
|
||||
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
|
||||
duration := time.Since(startTime)
|
||||
mh.ProcessingTimes[notificationType] = duration
|
||||
}
|
||||
|
||||
// Record errors
|
||||
if result != nil {
|
||||
mh.ErrorCounts[notificationType]++
|
||||
|
||||
// Log error details with connection information
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a summary of collected metrics.
|
||||
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"notification_counts": mh.NotificationCounts,
|
||||
"processing_times": mh.ProcessingTimes,
|
||||
"error_counts": mh.ErrorCounts,
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
|
||||
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
|
||||
// Get circuit breaker statistics
|
||||
stats := poolHook.GetCircuitBreakerStats()
|
||||
|
||||
for _, stat := range stats {
|
||||
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
|
||||
fmt.Printf(" State: %s\n", stat.State)
|
||||
fmt.Printf(" Failures: %d\n", stat.Failures)
|
||||
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
|
||||
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
|
||||
|
||||
// Alert if circuit breaker is open
|
||||
if stat.State.String() == "open" {
|
||||
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
+512
@@ -0,0 +1,512 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// handoffWorkerManager manages background workers and queue for connection handoffs
|
||||
type handoffWorkerManager struct {
|
||||
// Event-driven handoff support
|
||||
handoffQueue chan HandoffRequest // Queue for handoff requests
|
||||
shutdown chan struct{} // Shutdown signal
|
||||
shutdownOnce sync.Once // Ensure clean shutdown
|
||||
workerWg sync.WaitGroup // Track worker goroutines
|
||||
|
||||
// On-demand worker management
|
||||
maxWorkers int
|
||||
activeWorkers atomic.Int32
|
||||
workerTimeout time.Duration // How long workers wait for work before exiting
|
||||
workersScaling atomic.Bool
|
||||
|
||||
// Simple state tracking
|
||||
pending sync.Map // map[uint64]int64 (connID -> seqID)
|
||||
|
||||
// Configuration for the maintenance notifications
|
||||
config *Config
|
||||
|
||||
// Pool hook reference for handoff processing
|
||||
poolHook *PoolHook
|
||||
|
||||
// Circuit breaker manager for endpoint failure handling
|
||||
circuitBreakerManager *CircuitBreakerManager
|
||||
}
|
||||
|
||||
// newHandoffWorkerManager creates a new handoff worker manager
|
||||
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
|
||||
return &handoffWorkerManager{
|
||||
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
|
||||
shutdown: make(chan struct{}),
|
||||
maxWorkers: config.MaxWorkers,
|
||||
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
|
||||
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
|
||||
config: config,
|
||||
poolHook: poolHook,
|
||||
circuitBreakerManager: newCircuitBreakerManager(config),
|
||||
}
|
||||
}
|
||||
|
||||
// getCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
|
||||
return int(hwm.activeWorkers.Load())
|
||||
}
|
||||
|
||||
// getPendingMap returns the pending map for testing purposes
|
||||
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
|
||||
return &hwm.pending
|
||||
}
|
||||
|
||||
// getMaxWorkers returns the max workers for testing purposes
|
||||
func (hwm *handoffWorkerManager) getMaxWorkers() int {
|
||||
return hwm.maxWorkers
|
||||
}
|
||||
|
||||
// getHandoffQueue returns the handoff queue for testing purposes
|
||||
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
|
||||
return hwm.handoffQueue
|
||||
}
|
||||
|
||||
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return hwm.circuitBreakerManager.GetAllStats()
|
||||
}
|
||||
|
||||
// resetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
|
||||
hwm.circuitBreakerManager.Reset()
|
||||
}
|
||||
|
||||
// isHandoffPending returns true if the given connection has a pending handoff
|
||||
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
|
||||
_, pending := hwm.pending.Load(conn.GetID())
|
||||
return pending
|
||||
}
|
||||
|
||||
// ensureWorkerAvailable ensures at least one worker is available to process requests
|
||||
// Creates a new worker if needed and under the max limit
|
||||
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return
|
||||
default:
|
||||
if hwm.workersScaling.CompareAndSwap(false, true) {
|
||||
defer hwm.workersScaling.Store(false)
|
||||
// Check if we need a new worker
|
||||
currentWorkers := hwm.activeWorkers.Load()
|
||||
workersWas := currentWorkers
|
||||
for currentWorkers < int32(hwm.maxWorkers) {
|
||||
hwm.workerWg.Add(1)
|
||||
go hwm.onDemandWorker()
|
||||
currentWorkers++
|
||||
}
|
||||
// workersWas is always <= currentWorkers
|
||||
// currentWorkers will be maxWorkers, but if we have a worker that was closed
|
||||
// while we were creating new workers, just add the difference between
|
||||
// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
|
||||
hwm.activeWorkers.Add(currentWorkers - workersWas)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onDemandWorker processes handoff requests and exits when idle
|
||||
func (hwm *handoffWorkerManager) onDemandWorker() {
|
||||
defer func() {
|
||||
// Handle panics to ensure proper cleanup
|
||||
if r := recover(); r != nil {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
|
||||
}
|
||||
|
||||
// Decrement active worker count when exiting
|
||||
hwm.activeWorkers.Add(-1)
|
||||
hwm.workerWg.Done()
|
||||
}()
|
||||
|
||||
// Create reusable timer to prevent timer leaks
|
||||
timer := time.NewTimer(hwm.workerTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
// Reset timer for next iteration
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(hwm.workerTimeout)
|
||||
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
// Worker has been idle for too long, exit to save resources
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
|
||||
}
|
||||
return
|
||||
case request := <-hwm.handoffQueue:
|
||||
// Check for shutdown before processing
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
|
||||
}
|
||||
// Clean up the request before exiting
|
||||
hwm.pending.Delete(request.ConnID)
|
||||
return
|
||||
default:
|
||||
// Process the request
|
||||
hwm.processHandoffRequest(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processHandoffRequest processes a single handoff request
|
||||
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
|
||||
}
|
||||
|
||||
// Create a context with handoff timeout from config
|
||||
handoffTimeout := 15 * time.Second // Default timeout
|
||||
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
|
||||
handoffTimeout = hwm.config.HandoffTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a context that also respects the shutdown signal
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Monitor shutdown signal in a separate goroutine
|
||||
go func() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
shutdownCancel()
|
||||
case <-shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform the handoff with cancellable context
|
||||
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
|
||||
minRetryBackoff := 500 * time.Millisecond
|
||||
if err != nil {
|
||||
if shouldRetry {
|
||||
now := time.Now()
|
||||
deadline, ok := shutdownCtx.Deadline()
|
||||
thirdOfTimeout := handoffTimeout / 3
|
||||
if !ok || deadline.Before(now) {
|
||||
// wait half the timeout before retrying if no deadline or deadline has passed
|
||||
deadline = now.Add(thirdOfTimeout)
|
||||
}
|
||||
afterTime := deadline.Sub(now)
|
||||
if afterTime < minRetryBackoff {
|
||||
afterTime = minRetryBackoff
|
||||
}
|
||||
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
// Get current retry count for better logging
|
||||
currentRetries := request.Conn.HandoffRetries()
|
||||
maxRetries := 3 // Default fallback
|
||||
if hwm.config != nil {
|
||||
maxRetries = hwm.config.MaxHandoffRetries
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
|
||||
}
|
||||
// Schedule retry - keep connection in pending map until retry is queued
|
||||
time.AfterFunc(afterTime, func() {
|
||||
if err := hwm.queueHandoff(request.Conn); err != nil {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
|
||||
}
|
||||
// Failed to queue retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
hwm.closeConnFromRequest(context.Background(), request, err)
|
||||
} else {
|
||||
// Successfully queued retry - remove from pending (will be re-added by queueHandoff)
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
})
|
||||
return
|
||||
} else {
|
||||
// Won't retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
go hwm.closeConnFromRequest(ctx, request, err)
|
||||
}
|
||||
|
||||
// Clear handoff state if not returned for retry
|
||||
seqID := request.Conn.GetMovingSeqID()
|
||||
connID := request.Conn.GetID()
|
||||
if hwm.poolHook.operationsManager != nil {
|
||||
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
|
||||
}
|
||||
} else {
|
||||
// Success - remove from pending map
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
// queueHandoff queues a handoff request for processing
|
||||
// if err is returned, connection will be removed from pool
|
||||
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
|
||||
// Get handoff info atomically to prevent race conditions
|
||||
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
|
||||
|
||||
// on retries the connection will not be marked for handoff, but it will have retries > 0
|
||||
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
|
||||
if !shouldHandoff && conn.HandoffRetries() == 0 {
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
|
||||
}
|
||||
return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
|
||||
}
|
||||
|
||||
// Create handoff request with atomically retrieved data
|
||||
request := HandoffRequest{
|
||||
Conn: conn,
|
||||
ConnID: conn.GetID(),
|
||||
Endpoint: endpoint,
|
||||
SeqID: seqID,
|
||||
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
|
||||
}
|
||||
|
||||
select {
|
||||
// priority to shutdown
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
case <-time.After(100 * time.Millisecond): // give workers a chance to process
|
||||
// Queue is full - log and attempt scaling
|
||||
queueLen := len(hwm.handoffQueue)
|
||||
queueCap := cap(hwm.handoffQueue)
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have workers available to handle the load
|
||||
hwm.ensureWorkerAvailable()
|
||||
return ErrHandoffQueueFull
|
||||
}
|
||||
|
||||
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
|
||||
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
|
||||
hwm.shutdownOnce.Do(func() {
|
||||
close(hwm.shutdown)
|
||||
// workers will exit when they finish their current request
|
||||
|
||||
// Shutdown circuit breaker manager cleanup goroutine
|
||||
if hwm.circuitBreakerManager != nil {
|
||||
hwm.circuitBreakerManager.Shutdown()
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for workers to complete
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hwm.workerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// performConnectionHandoff performs the actual connection handoff
|
||||
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
|
||||
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
|
||||
// Clear handoff state after successful handoff
|
||||
connID := conn.GetID()
|
||||
|
||||
newEndpoint := conn.GetHandoffEndpoint()
|
||||
if newEndpoint == "" {
|
||||
return false, ErrConnectionInvalidHandoffState
|
||||
}
|
||||
|
||||
// Use circuit breaker to protect against failing endpoints
|
||||
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
|
||||
|
||||
// Check if circuit breaker is open before attempting handoff
|
||||
if circuitBreaker.IsOpen() {
|
||||
internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
|
||||
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
|
||||
}
|
||||
|
||||
// Perform the handoff
|
||||
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
|
||||
|
||||
// Update circuit breaker based on result
|
||||
if err != nil {
|
||||
// Only track dial/network errors in circuit breaker, not initialization errors
|
||||
if shouldRetry {
|
||||
circuitBreaker.recordFailure()
|
||||
}
|
||||
return shouldRetry, err
|
||||
}
|
||||
|
||||
// Success - record in circuit breaker
|
||||
circuitBreaker.recordSuccess()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
|
||||
func (hwm *handoffWorkerManager) performHandoffInternal(
|
||||
ctx context.Context,
|
||||
conn *pool.Conn,
|
||||
newEndpoint string,
|
||||
connID uint64,
|
||||
) (shouldRetry bool, err error) {
|
||||
retries := conn.IncrementAndGetHandoffRetries(1)
|
||||
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
|
||||
maxRetries := 3 // Default fallback
|
||||
if hwm.config != nil {
|
||||
maxRetries = hwm.config.MaxHandoffRetries
|
||||
}
|
||||
|
||||
if retries > maxRetries {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
|
||||
}
|
||||
// won't retry on ErrMaxHandoffRetriesReached
|
||||
return false, ErrMaxHandoffRetriesReached
|
||||
}
|
||||
|
||||
// Create endpoint-specific dialer
|
||||
endpointDialer := hwm.createEndpointDialer(newEndpoint)
|
||||
|
||||
// Create new connection to the new endpoint
|
||||
newNetConn, err := endpointDialer(ctx)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
|
||||
// will retry
|
||||
// Maybe a network error - retry after a delay
|
||||
return true, err
|
||||
}
|
||||
|
||||
// Get the old connection
|
||||
oldConn := conn.GetNetConn()
|
||||
|
||||
// Apply relaxed timeout to the new connection for the configured post-handoff duration
|
||||
// This gives the new connection more time to handle operations during cluster transition
|
||||
// Setting this here (before initing the connection) ensures that the connection is going
|
||||
// to use the relaxed timeout for the first operation (auth/ACL select)
|
||||
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
|
||||
relaxedTimeout := hwm.config.RelaxedTimeout
|
||||
// Set relaxed timeout with deadline - no background goroutine needed
|
||||
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
|
||||
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
|
||||
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the connection and execute initialization
|
||||
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
|
||||
if err != nil {
|
||||
// won't retry
|
||||
// Initialization failed - remove the connection
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if oldConn != nil {
|
||||
oldConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Clear handoff state will:
|
||||
// - set the connection as usable again
|
||||
// - clear the handoff state (shouldHandoff, endpoint, seqID)
|
||||
// - reset the handoff retries to 0
|
||||
// Note: Theoretically there may be a short window where the connection is in the pool
|
||||
// and IDLE (initConn completed) but still has handoff state set.
|
||||
conn.ClearHandoffState()
|
||||
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
|
||||
|
||||
// successfully completed the handoff, no retry needed and no error
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// createEndpointDialer creates a dialer function that connects to a specific endpoint
|
||||
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Parse endpoint to extract host and port
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
// If no port specified, assume default Redis port
|
||||
host = endpoint
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
}
|
||||
|
||||
// Use the base dialer to connect to the new endpoint
|
||||
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnFromRequest closes the connection and logs the reason
|
||||
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
|
||||
pooler := request.Pool
|
||||
conn := request.Conn
|
||||
|
||||
// Clear handoff state before closing
|
||||
conn.ClearHandoffState()
|
||||
|
||||
if pooler != nil {
|
||||
// Use RemoveWithoutTurn instead of Remove to avoid freeing a turn that we don't have.
|
||||
// The handoff worker doesn't call Get(), so it doesn't have a turn to free.
|
||||
// Remove() is meant to be called after Get() and frees a turn.
|
||||
// RemoveWithoutTurn() removes and closes the connection without affecting the queue.
|
||||
pooler.RemoveWithoutTurn(ctx, conn, err)
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
|
||||
}
|
||||
} else {
|
||||
err := conn.Close() // Close the connection if no pool provided
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
|
||||
}
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
|
||||
}
|
||||
}
|
||||
}
|
||||
+60
@@ -0,0 +1,60 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// LoggingHook is an example notification hook that logs every notification
// it sees, filtered by LogLevel.
type LoggingHook struct {
	// LogLevel selects the verbosity: 0=Error, 1=Warn, 2=Info, 3=Debug.
	LogLevel int
}
|
||||
|
||||
// PreHook logs the notification before processing and allows modification.
|
||||
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
if lh.LogLevel >= 2 { // Info level
|
||||
// Log the notification type and content
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
seqID := int64(0)
|
||||
if slices.Contains(maintenanceNotificationTypes, notificationType) {
|
||||
// seqID is the second element in the notification array
|
||||
if len(notification) > 1 {
|
||||
if parsedSeqID, ok := notification[1].(int64); !ok {
|
||||
seqID = 0
|
||||
} else {
|
||||
seqID = parsedSeqID
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
|
||||
}
|
||||
return notification, true // Continue processing with unmodified notification
|
||||
}
|
||||
|
||||
// PostHook logs the result after processing.
|
||||
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
if result != nil && lh.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
|
||||
} else if lh.LogLevel >= 3 { // Debug level
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
|
||||
}
|
||||
}
|
||||
|
||||
// NewLoggingHook creates a new logging hook with the specified log level.
|
||||
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
|
||||
func NewLoggingHook(logLevel int) *LoggingHook {
|
||||
return &LoggingHook{LogLevel: logLevel}
|
||||
}
|
||||
+320
@@ -0,0 +1,320 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Push notification type names handled by the maintenance notifications
// manager.
const (
	// NotificationMoving is the "MOVING" push notification type.
	NotificationMoving = "MOVING"
	// NotificationMigrating is the "MIGRATING" push notification type.
	NotificationMigrating = "MIGRATING"
	// NotificationMigrated is the "MIGRATED" push notification type.
	NotificationMigrated = "MIGRATED"
	// NotificationFailingOver is the "FAILING_OVER" push notification type.
	NotificationFailingOver = "FAILING_OVER"
	// NotificationFailedOver is the "FAILED_OVER" push notification type.
	NotificationFailedOver = "FAILED_OVER"
)
|
||||
|
||||
// maintenanceNotificationTypes contains all notification types that maintenance handles
|
||||
var maintenanceNotificationTypes = []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
// NotificationHook is called before and after notification processing
|
||||
// PreHook can modify the notification and return false to skip processing
|
||||
// PostHook is called after successful processing
|
||||
type NotificationHook interface {
|
||||
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
|
||||
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
|
||||
}
|
||||
|
||||
// MovingOperationKey identifies a tracked MOVING operation. The sequence ID
// alone is not unique because several connections to the same node can
// receive the same sequence ID, so the key also carries the connection ID.
type MovingOperationKey struct {
	// SeqID is the sequence ID from the MOVING notification.
	SeqID int64
	// ConnID is the unique identifier of the connection.
	ConnID uint64
}
|
||||
|
||||
// String returns a string representation of the key for debugging
|
||||
func (k MovingOperationKey) String() string {
|
||||
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
|
||||
}
|
||||
|
||||
// Manager provides a simplified upgrade functionality with hooks and atomic state.
|
||||
type Manager struct {
|
||||
client interfaces.ClientInterface
|
||||
config *Config
|
||||
options interfaces.OptionsInterface
|
||||
pool pool.Pooler
|
||||
|
||||
// MOVING operation tracking - using sync.Map for better concurrent performance
|
||||
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
|
||||
|
||||
// Atomic state tracking - no locks needed for state queries
|
||||
activeOperationCount atomic.Int64 // Number of active operations
|
||||
closed atomic.Bool // Manager closed state
|
||||
|
||||
// Notification hooks for extensibility
|
||||
hooks []NotificationHook
|
||||
hooksMu sync.RWMutex // Protects hooks slice
|
||||
poolHooksRef *PoolHook
|
||||
}
|
||||
|
||||
// MovingOperation is the record kept for one active MOVING operation.
type MovingOperation struct {
	SeqID       int64     // Sequence ID of the MOVING notification
	NewEndpoint string    // Endpoint the connection is being moved to
	StartTime   time.Time // When tracking of the operation started
	Deadline    time.Time // Deadline supplied when the operation was tracked
}
|
||||
|
||||
// NewManager creates a new simplified manager.
|
||||
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
|
||||
if client == nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
hm := &Manager{
|
||||
client: client,
|
||||
pool: pool,
|
||||
options: client.GetOptions(),
|
||||
config: config.Clone(),
|
||||
hooks: make([]NotificationHook, 0),
|
||||
}
|
||||
|
||||
// Set up push notification handling
|
||||
if err := hm.setupPushNotifications(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// GetPoolHook creates a pool hook with a custom dialer.
|
||||
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
|
||||
poolHook := hm.createPoolHook(baseDialer)
|
||||
hm.pool.AddPoolHook(poolHook)
|
||||
}
|
||||
|
||||
// setupPushNotifications sets up push notification handling by registering with the client's processor.
|
||||
func (hm *Manager) setupPushNotifications() error {
|
||||
processor := hm.client.GetPushProcessor()
|
||||
if processor == nil {
|
||||
return ErrInvalidClient // Client doesn't support push notifications
|
||||
}
|
||||
|
||||
// Create our notification handler
|
||||
handler := &NotificationHandler{manager: hm, operationsManager: hm}
|
||||
|
||||
// Register handlers for all upgrade notifications with the client's processor
|
||||
for _, notificationType := range maintenanceNotificationTypes {
|
||||
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
|
||||
return errors.New(logs.FailedToRegisterHandler(notificationType, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
|
||||
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Create MOVING operation record
|
||||
movingOp := &MovingOperation{
|
||||
SeqID: seqID,
|
||||
NewEndpoint: newEndpoint,
|
||||
StartTime: time.Now(),
|
||||
Deadline: deadline,
|
||||
}
|
||||
|
||||
// Use LoadOrStore for atomic check-and-set operation
|
||||
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
|
||||
// Duplicate MOVING notification, ignore
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
|
||||
}
|
||||
|
||||
// Increment active operation count atomically
|
||||
hm.activeOperationCount.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
|
||||
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Remove from active operations atomically
|
||||
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
|
||||
}
|
||||
// Decrement active operation count only if operation existed
|
||||
hm.activeOperationCount.Add(-1)
|
||||
} else {
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveMovingOperations returns active operations with composite keys.
|
||||
// WARNING: This method creates a new map and copies all operations on every call.
|
||||
// Use sparingly, especially in hot paths or high-frequency logging.
|
||||
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
|
||||
result := make(map[MovingOperationKey]*MovingOperation)
|
||||
|
||||
// Iterate over sync.Map to build result
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
k := key.(MovingOperationKey)
|
||||
op := value.(*MovingOperation)
|
||||
|
||||
// Create a copy to avoid sharing references
|
||||
result[k] = &MovingOperation{
|
||||
SeqID: op.SeqID,
|
||||
NewEndpoint: op.NewEndpoint,
|
||||
StartTime: op.StartTime,
|
||||
Deadline: op.Deadline,
|
||||
}
|
||||
return true // Continue iteration
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHandoffInProgress returns true if any handoff is in progress.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *Manager) IsHandoffInProgress() bool {
|
||||
return hm.activeOperationCount.Load() > 0
|
||||
}
|
||||
|
||||
// GetActiveOperationCount returns the number of active operations.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *Manager) GetActiveOperationCount() int64 {
|
||||
return hm.activeOperationCount.Load()
|
||||
}
|
||||
|
||||
// Close closes the manager.
|
||||
func (hm *Manager) Close() error {
|
||||
// Use atomic operation for thread-safe close check
|
||||
if !hm.closed.CompareAndSwap(false, true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
// Shutdown the pool hook if it exists
|
||||
if hm.poolHooksRef != nil {
|
||||
// Use a timeout to prevent hanging indefinitely
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := hm.poolHooksRef.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
// was not able to close pool hook, keep closed state false
|
||||
hm.closed.Store(false)
|
||||
return err
|
||||
}
|
||||
// Remove the pool hook from the pool
|
||||
if hm.pool != nil {
|
||||
hm.pool.RemovePoolHook(hm.poolHooksRef)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear all active operations
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
hm.activeMovingOps.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Reset counter
|
||||
hm.activeOperationCount.Store(0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetState returns current state using atomic counter for lock-free operation.
|
||||
func (hm *Manager) GetState() State {
|
||||
if hm.activeOperationCount.Load() > 0 {
|
||||
return StateMoving
|
||||
}
|
||||
return StateIdle
|
||||
}
|
||||
|
||||
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
|
||||
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
currentNotification := notification
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
|
||||
if !shouldContinue {
|
||||
return modifiedNotification, false
|
||||
}
|
||||
currentNotification = modifiedNotification
|
||||
}
|
||||
|
||||
return currentNotification, true
|
||||
}
|
||||
|
||||
// processPostHooks calls all post-hooks with the processing result.
|
||||
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
|
||||
}
|
||||
}
|
||||
|
||||
// createPoolHook creates a pool hook with this manager already set.
|
||||
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
|
||||
if hm.poolHooksRef != nil {
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
// Get pool size from client options for better worker defaults
|
||||
poolSize := 0
|
||||
if hm.options != nil {
|
||||
poolSize = hm.options.GetPoolSize()
|
||||
}
|
||||
|
||||
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
|
||||
hm.poolHooksRef.SetPool(hm.pool)
|
||||
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
|
||||
func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
|
||||
hm.hooksMu.Lock()
|
||||
defer hm.hooksMu.Unlock()
|
||||
hm.hooks = append(hm.hooks, notificationHook)
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// OperationsManagerInterface is the subset of the manager that the pool hook
// needs in order to track and untrack MOVING operations.
type OperationsManagerInterface interface {
	// TrackMovingOperationWithConnID starts tracking a MOVING operation for
	// the given sequence ID and connection ID.
	TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
	// UntrackOperationWithConnID stops tracking the operation for the given
	// sequence ID and connection ID.
	UntrackOperationWithConnID(seqID int64, connID uint64)
}
|
||||
|
||||
// HandoffRequest represents a request to handoff a connection to a new endpoint
|
||||
type HandoffRequest struct {
|
||||
Conn *pool.Conn
|
||||
ConnID uint64 // Unique connection identifier
|
||||
Endpoint string
|
||||
SeqID int64
|
||||
Pool pool.Pooler // Pool to remove connection from on failure
|
||||
}
|
||||
|
||||
// PoolHook implements pool.PoolHook for Redis-specific connection handling
|
||||
// with maintenance notifications support.
|
||||
type PoolHook struct {
|
||||
// Base dialer for creating connections to new endpoints during handoffs
|
||||
// args are network and address
|
||||
baseDialer func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// Network type (e.g., "tcp", "unix")
|
||||
network string
|
||||
|
||||
// Worker manager for background handoff processing
|
||||
workerManager *handoffWorkerManager
|
||||
|
||||
// Configuration for the maintenance notifications
|
||||
config *Config
|
||||
|
||||
// Operations manager interface for operation completion tracking
|
||||
operationsManager OperationsManagerInterface
|
||||
|
||||
// Pool interface for removing connections on handoff failure
|
||||
pool pool.Pooler
|
||||
}
|
||||
|
||||
// NewPoolHook creates a new pool hook
|
||||
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
|
||||
return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
|
||||
}
|
||||
|
||||
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
|
||||
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
|
||||
// Apply defaults if config is nil or has zero values
|
||||
if config == nil {
|
||||
config = config.ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
ph := &PoolHook{
|
||||
// baseDialer is used to create connections to new endpoints during handoffs
|
||||
baseDialer: baseDialer,
|
||||
network: network,
|
||||
config: config,
|
||||
operationsManager: operationsManager,
|
||||
}
|
||||
|
||||
// Create worker manager
|
||||
ph.workerManager = newHandoffWorkerManager(config, ph)
|
||||
|
||||
return ph
|
||||
}
|
||||
|
||||
// SetPool sets the pool interface for removing connections on handoff failure
|
||||
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
|
||||
ph.pool = pooler
|
||||
}
|
||||
|
||||
// GetCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (ph *PoolHook) GetCurrentWorkers() int {
|
||||
return ph.workerManager.getCurrentWorkers()
|
||||
}
|
||||
|
||||
// IsHandoffPending returns true if the given connection has a pending handoff
|
||||
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
|
||||
return ph.workerManager.isHandoffPending(conn)
|
||||
}
|
||||
|
||||
// GetPendingMap returns the pending map for testing purposes
|
||||
func (ph *PoolHook) GetPendingMap() *sync.Map {
|
||||
return ph.workerManager.getPendingMap()
|
||||
}
|
||||
|
||||
// GetMaxWorkers returns the max workers for testing purposes
|
||||
func (ph *PoolHook) GetMaxWorkers() int {
|
||||
return ph.workerManager.getMaxWorkers()
|
||||
}
|
||||
|
||||
// GetHandoffQueue returns the handoff queue for testing purposes
|
||||
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
|
||||
return ph.workerManager.getHandoffQueue()
|
||||
}
|
||||
|
||||
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return ph.workerManager.getCircuitBreakerStats()
|
||||
}
|
||||
|
||||
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (ph *PoolHook) ResetCircuitBreakers() {
|
||||
ph.workerManager.resetCircuitBreakers()
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
// Check if connection is marked for handoff
|
||||
// This prevents using connections that have received MOVING notifications
|
||||
if conn.ShouldHandoff() {
|
||||
return false, ErrConnectionMarkedForHandoffWithState
|
||||
}
|
||||
|
||||
// Check if connection is usable (not in UNUSABLE or CLOSED state)
|
||||
// This ensures we don't return connections that are currently being handed off or re-authenticated.
|
||||
if !conn.IsUsable() {
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool
|
||||
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// first check if we should handoff for faster rejection
|
||||
if !conn.ShouldHandoff() {
|
||||
// Default behavior (no handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// check pending handoff to not queue the same connection twice
|
||||
if ph.workerManager.isHandoffPending(conn) {
|
||||
// Default behavior (pending handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := ph.workerManager.queueHandoff(conn); err != nil {
|
||||
// Failed to queue handoff, remove the connection
|
||||
internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
|
||||
// Don't pool, remove connection, no error to caller
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
// Check if handoff was already processed by a worker before we can mark it as queued
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was already processed - this is normal and the connection should be pooled
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := conn.MarkQueuedForHandoff(); err != nil {
|
||||
// If marking fails, check if handoff was processed in the meantime
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was processed - this is normal, pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
// Other error - remove the connection
|
||||
return false, true, nil
|
||||
}
|
||||
internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) {
|
||||
// Not used
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the processor, waiting for workers to complete
|
||||
func (ph *PoolHook) Shutdown(ctx context.Context) error {
|
||||
return ph.workerManager.shutdownWorkers(ctx)
|
||||
}
|
||||
Generated
Vendored
+282
@@ -0,0 +1,282 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NotificationHandler handles push notifications for the simplified manager.
|
||||
type NotificationHandler struct {
|
||||
manager *Manager
|
||||
operationsManager OperationsManagerInterface
|
||||
}
|
||||
|
||||
// HandlePushNotification processes push notifications with hook support.
|
||||
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Process pre-hooks - they can modify the notification or skip processing
|
||||
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
|
||||
if !shouldContinue {
|
||||
return nil // Hooks decided to skip processing
|
||||
}
|
||||
|
||||
var err error
|
||||
switch notificationType {
|
||||
case NotificationMoving:
|
||||
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrating:
|
||||
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrated:
|
||||
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailingOver:
|
||||
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailedOver:
|
||||
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
|
||||
default:
|
||||
// Ignore other notification types (e.g., pub/sub messages)
|
||||
err = nil
|
||||
}
|
||||
|
||||
// Process post-hooks with the result
|
||||
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMoving processes MOVING notifications.
|
||||
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
|
||||
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
seqID, ok := notification[1].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Extract timeS
|
||||
timeS, ok := notification[2].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
newEndpoint := ""
|
||||
if len(notification) > 3 {
|
||||
// Extract new endpoint
|
||||
newEndpoint, ok = notification[3].(string)
|
||||
if !ok {
|
||||
stringified := fmt.Sprintf("%v", notification[3])
|
||||
// this could be <nil> which is valid
|
||||
if notification[3] == nil || stringified == internal.RedisNull {
|
||||
newEndpoint = ""
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the connection that received this notification
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to get the underlying pool connection
|
||||
var poolConn *pool.Conn
|
||||
if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// If the connection is closed or not pooled, we can ignore the notification
|
||||
// this connection won't be remembered by the pool and will be garbage collected
|
||||
// Keep pubsub connections around since they are not pooled but are long-lived
|
||||
// and should be allowed to handoff (the pubsub instance will reconnect and change
|
||||
// the underlying *pool.Conn)
|
||||
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
|
||||
return nil
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
|
||||
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
|
||||
if newEndpoint == "" || newEndpoint == internal.RedisNull {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
|
||||
}
|
||||
// same as current endpoint
|
||||
newEndpoint = snh.manager.options.GetAddr()
|
||||
// delay the handoff for timeS/2 seconds to the same endpoint
|
||||
// do this in a goroutine to avoid blocking the notification handler
|
||||
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
|
||||
// and there should be no possibility of a race condition or double handoff.
|
||||
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
|
||||
if poolConn == nil || poolConn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
|
||||
// Log error but don't fail the goroutine - use background context since original may be cancelled
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
|
||||
}
|
||||
|
||||
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
|
||||
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
|
||||
// Connection is already marked for handoff, which is acceptable
|
||||
// This can happen if multiple MOVING notifications are received for the same connection
|
||||
return nil
|
||||
}
|
||||
// Optionally track in m
|
||||
if snh.operationsManager != nil {
|
||||
connID := conn.GetID()
|
||||
// Track the operation (ignore errors since this is optional)
|
||||
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
|
||||
} else {
|
||||
return errors.New(logs.ManagerNotInitialized())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrating processes MIGRATING notifications.
|
||||
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATING notifications indicate that a connection is about to be migrated
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrated processes MIGRATED notifications.
|
||||
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATED notifications indicate that a connection migration has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOver processes FAILING_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILING_OVER notifications indicate that a connection is about to failover
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOver processes FAILED_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILED_OVER notifications indicate that a connection failover has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
package maintnotifications
|
||||
|
||||
// State represents the current state of a maintenance operation.
type State int

const (
	// StateIdle indicates no upgrade is in progress.
	StateIdle State = iota

	// StateMoving indicates a connection handoff is in progress.
	StateMoving
)

// String returns a string representation of the state: "idle" for StateIdle,
// "moving" for StateMoving, and "unknown" for any other value.
func (s State) String() string {
	switch s {
	case StateIdle:
		return "idle"
	case StateMoving:
		return "moving"
	default:
		return "unknown"
	}
}
|
||||
+149
-14
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Limiter is the interface of a rate limiter or a circuit breaker.
|
||||
@@ -31,7 +34,6 @@ type Limiter interface {
|
||||
|
||||
// Options keeps the settings to set up redis connection.
|
||||
type Options struct {
|
||||
|
||||
// Network type, either tcp or unix.
|
||||
//
|
||||
// default: is tcp.
|
||||
@@ -109,6 +111,16 @@ type Options struct {
|
||||
// default: 5 seconds
|
||||
DialTimeout time.Duration
|
||||
|
||||
// DialerRetries is the maximum number of retry attempts when dialing fails.
|
||||
//
|
||||
// default: 5
|
||||
DialerRetries int
|
||||
|
||||
// DialerRetryTimeout is the backoff duration between retry attempts.
|
||||
//
|
||||
// default: 100 milliseconds
|
||||
DialerRetryTimeout time.Duration
|
||||
|
||||
// ReadTimeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
//
|
||||
@@ -152,6 +164,7 @@ type Options struct {
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
// default: false
|
||||
PoolFIFO bool
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
@@ -162,6 +175,10 @@ type Options struct {
|
||||
// default: 10 * runtime.GOMAXPROCS(0)
|
||||
PoolSize int
|
||||
|
||||
// MaxConcurrentDials is the maximum number of concurrent connection creation goroutines.
|
||||
// If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize.
|
||||
MaxConcurrentDials int
|
||||
|
||||
// PoolTimeout is the amount of time client waits for connection if all connections
|
||||
// are busy before returning an error.
|
||||
//
|
||||
@@ -232,10 +249,24 @@ type Options struct {
|
||||
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
|
||||
UnstableResp3 bool
|
||||
|
||||
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
|
||||
// and are not available for RESP2 connections. No configuration option is needed.
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *Options) init() {
|
||||
@@ -255,12 +286,23 @@ func (opt *Options) init() {
|
||||
if opt.DialTimeout == 0 {
|
||||
opt.DialTimeout = 5 * time.Second
|
||||
}
|
||||
if opt.DialerRetries == 0 {
|
||||
opt.DialerRetries = 5
|
||||
}
|
||||
if opt.DialerRetryTimeout == 0 {
|
||||
opt.DialerRetryTimeout = 100 * time.Millisecond
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = NewDialer(opt)
|
||||
}
|
||||
if opt.PoolSize == 0 {
|
||||
opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
|
||||
}
|
||||
if opt.MaxConcurrentDials <= 0 {
|
||||
opt.MaxConcurrentDials = opt.PoolSize
|
||||
} else if opt.MaxConcurrentDials > opt.PoolSize {
|
||||
opt.MaxConcurrentDials = opt.PoolSize
|
||||
}
|
||||
if opt.ReadBufferSize == 0 {
|
||||
opt.ReadBufferSize = proto.DefaultBufferSize
|
||||
}
|
||||
@@ -312,13 +354,40 @@ func (opt *Options) init() {
|
||||
case 0:
|
||||
opt.MaxRetryBackoff = 512 * time.Millisecond
|
||||
}
|
||||
|
||||
if opt.FailingTimeoutSeconds == 0 {
|
||||
opt.FailingTimeoutSeconds = 15
|
||||
}
|
||||
|
||||
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
|
||||
|
||||
// auto-detect endpoint type if not specified
|
||||
endpointType := opt.MaintNotificationsConfig.EndpointType
|
||||
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
|
||||
// Auto-detect endpoint type if not specified
|
||||
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
|
||||
}
|
||||
opt.MaintNotificationsConfig.EndpointType = endpointType
|
||||
}
|
||||
|
||||
func (opt *Options) clone() *Options {
|
||||
clone := *opt
|
||||
|
||||
// Deep clone MaintNotificationsConfig to avoid sharing between clients
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
clone.MaintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
|
||||
return NewDialer(opt)
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
|
||||
@@ -565,6 +634,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) {
|
||||
o.MinIdleConns = q.int("min_idle_conns")
|
||||
o.MaxIdleConns = q.int("max_idle_conns")
|
||||
o.MaxActiveConns = q.int("max_active_conns")
|
||||
o.MaxConcurrentDials = q.int("max_concurrent_dials")
|
||||
if q.has("conn_max_idle_time") {
|
||||
o.ConnMaxIdleTime = q.duration("conn_max_idle_time")
|
||||
} else {
|
||||
@@ -604,21 +674,86 @@ func getUserPassword(u *url.URL) (string, string) {
|
||||
func newConnPool(
|
||||
opt *Options,
|
||||
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) *pool.ConnPool {
|
||||
) (*pool.ConnPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer(ctx, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: opt.PoolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
MaxIdleConns: opt.MaxIdleConns,
|
||||
MaxActiveConns: opt.MaxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
})
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
MaxConcurrentDials: opt.MaxConcurrentDials,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) (*pool.PubSubPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewPubSubPool(&pool.Options{
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
MaxConcurrentDials: opt.MaxConcurrentDials,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}, dialer), nil
|
||||
}
|
||||
|
||||
+67
-12
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +40,7 @@ type ClusterOptions struct {
|
||||
ClientName string
|
||||
|
||||
// NewClient creates a cluster node client with provided name and options.
|
||||
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
|
||||
NewClient func(opt *Options) *Client
|
||||
|
||||
// The maximum number of retries before giving up. Command is retried
|
||||
@@ -74,6 +77,10 @@ type ClusterOptions struct {
|
||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up.
|
||||
// For ClusterClient, retries are disabled by default (set to -1),
|
||||
// because the cluster client handles all kinds of retries internally.
|
||||
// This is intentional and differs from the standalone Options default.
|
||||
MaxRetries int
|
||||
MinRetryBackoff time.Duration
|
||||
MaxRetryBackoff time.Duration
|
||||
@@ -125,10 +132,22 @@ type ClusterOptions struct {
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
UnstableResp3 bool
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
// The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) init() {
|
||||
@@ -319,6 +338,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
|
||||
var maintNotificationsConfig *maintnotifications.Config
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
maintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &Options{
|
||||
ClientName: opt.ClientName,
|
||||
Dialer: opt.Dialer,
|
||||
@@ -360,8 +386,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// much use for ClusterSlots config). This means we cannot execute the
|
||||
// READONLY command against that node -- setting readOnly to false in such
|
||||
// situations in the options below will prevent that from happening.
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
MaintNotificationsConfig: maintNotificationsConfig,
|
||||
PushNotificationProcessor: opt.PushNotificationProcessor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1664,7 +1692,7 @@ func (c *ClusterClient) processTxPipelineNode(
|
||||
}
|
||||
|
||||
func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
) error {
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
@@ -1682,7 +1710,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
trimmedCmds := cmds[1 : len(cmds)-1]
|
||||
|
||||
if err := c.txPipelineReadQueued(
|
||||
ctx, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
|
||||
@@ -1694,23 +1722,37 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
return err
|
||||
}
|
||||
|
||||
return pipelineReadCmds(rd, trimmedCmds)
|
||||
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClusterClient) txPipelineReadQueued(
|
||||
ctx context.Context,
|
||||
node *clusterNode,
|
||||
cn *pool.Conn,
|
||||
rd *proto.Reader,
|
||||
statusCmd *StatusCmd,
|
||||
cmds []Cmder,
|
||||
failedCmds *cmdsMap,
|
||||
) error {
|
||||
// Parse queued replies.
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
if err := statusCmd.readReply(rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, cmd := range cmds {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
err := statusCmd.readReply(rd)
|
||||
if err != nil {
|
||||
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
|
||||
@@ -1724,6 +1766,12 @@ func (c *ClusterClient) txPipelineReadQueued(
|
||||
}
|
||||
}
|
||||
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
// Parse number of replies.
|
||||
line, err := rd.ReadLine()
|
||||
if err != nil {
|
||||
@@ -1829,12 +1877,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
|
||||
return err
|
||||
}
|
||||
|
||||
// maintenance notifications won't work here for now
|
||||
func (c *ClusterClient) pubSub() *PubSub {
|
||||
var node *clusterNode
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt.clientOptions(),
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
if node != nil {
|
||||
panic("node != nil")
|
||||
}
|
||||
@@ -1868,18 +1916,25 @@ func (c *ClusterClient) pubSub() *PubSub {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cn, err := node.Client.newConn(context.TODO())
|
||||
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
|
||||
if err != nil {
|
||||
node = nil
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// will return nil if already initialized
|
||||
err = node.Client.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
node = nil
|
||||
return nil, err
|
||||
}
|
||||
node.Client.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
err := node.Client.connPool.CloseConn(cn)
|
||||
// Untrack connection from PubSubPool
|
||||
node.Client.pubSubPool.UntrackConn(cn)
|
||||
err := cn.Close()
|
||||
node = nil
|
||||
return err
|
||||
},
|
||||
|
||||
+74
-7
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// PubSub implements Pub/Sub commands as described in
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
type PubSub struct {
|
||||
opt *Options
|
||||
|
||||
newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
|
||||
newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error)
|
||||
closeConn func(*pool.Conn) error
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -38,6 +39,12 @@ type PubSub struct {
|
||||
chOnce sync.Once
|
||||
msgCh *channel
|
||||
allCh *channel
|
||||
|
||||
// Push notification processor for handling generic push notifications
|
||||
pushProcessor push.NotificationProcessor
|
||||
|
||||
// Cleanup callback for maintenanceNotifications upgrade tracking
|
||||
onClose func()
|
||||
}
|
||||
|
||||
func (c *PubSub) init() {
|
||||
@@ -69,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er
|
||||
return c.cn, nil
|
||||
}
|
||||
|
||||
if c.opt.Addr == "" {
|
||||
// TODO(maintenanceNotifications):
|
||||
// this is probably cluster client
|
||||
// c.newConn will ignore the addr argument
|
||||
// will be changed when we have maintenanceNotifications upgrades for cluster clients
|
||||
c.opt.Addr = internal.RedisNull
|
||||
}
|
||||
|
||||
channels := mapKeys(c.channels)
|
||||
channels = append(channels, newChannels...)
|
||||
|
||||
cn, err := c.newConn(ctx, channels)
|
||||
cn, err := c.newConn(ctx, c.opt.Addr, channels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -153,12 +168,31 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo
|
||||
if c.cn != cn {
|
||||
return
|
||||
}
|
||||
|
||||
if !cn.IsUsable() || cn.ShouldHandoff() {
|
||||
c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable"))
|
||||
}
|
||||
|
||||
if isBadConn(err, allowTimeout, c.opt.Addr) {
|
||||
c.reconnect(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PubSub) reconnect(ctx context.Context, reason error) {
|
||||
if c.cn != nil && c.cn.ShouldHandoff() {
|
||||
newEndpoint := c.cn.GetHandoffEndpoint()
|
||||
// If new endpoint is NULL, use the original address
|
||||
if newEndpoint == internal.RedisNull {
|
||||
newEndpoint = c.opt.Addr
|
||||
}
|
||||
|
||||
if newEndpoint != "" {
|
||||
// Update the address in the options
|
||||
oldAddr := c.cn.RemoteAddr().String()
|
||||
c.opt.Addr = newEndpoint
|
||||
internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr)
|
||||
}
|
||||
}
|
||||
_ = c.closeTheCn(reason)
|
||||
_, _ = c.conn(ctx, nil)
|
||||
}
|
||||
@@ -167,9 +201,6 @@ func (c *PubSub) closeTheCn(reason error) error {
|
||||
if c.cn == nil {
|
||||
return nil
|
||||
}
|
||||
if !c.closed {
|
||||
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
|
||||
}
|
||||
err := c.closeConn(c.cn)
|
||||
c.cn = nil
|
||||
return err
|
||||
@@ -185,6 +216,11 @@ func (c *PubSub) Close() error {
|
||||
c.closed = true
|
||||
close(c.exit)
|
||||
|
||||
// Call cleanup callback if set
|
||||
if c.onClose != nil {
|
||||
c.onClose()
|
||||
}
|
||||
|
||||
return c.closeTheCn(pool.ErrClosed)
|
||||
}
|
||||
|
||||
@@ -429,16 +465,20 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
|
||||
}
|
||||
|
||||
// Don't hold the lock to allow subscriptions and pings.
|
||||
|
||||
cn, err := c.connWithLock(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err)
|
||||
}
|
||||
return c.cmd.readReply(rd)
|
||||
})
|
||||
|
||||
c.releaseConnWithLock(ctx, cn, err, timeout > 0)
|
||||
|
||||
if err != nil {
|
||||
@@ -451,6 +491,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int
|
||||
// Receive returns a message as a Subscription, Message, Pong or error.
|
||||
// See PubSub example for details. This is low-level API and in most cases
|
||||
// Channel should be used instead.
|
||||
// Receive returns a message as a Subscription, Message, Pong, or an error.
|
||||
// See PubSub example for details. This is a low-level API and in most cases
|
||||
// Channel should be used instead.
|
||||
// This method blocks until a message is received or an error occurs.
|
||||
// It may return early with an error if the context is canceled, the connection fails,
|
||||
// or other internal errors occur.
|
||||
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
|
||||
return c.ReceiveTimeout(ctx, 0)
|
||||
}
|
||||
@@ -532,6 +578,27 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac
|
||||
return c.allCh.allCh
|
||||
}
|
||||
|
||||
func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
|
||||
// Only process push notifications for RESP3 connections with a processor
|
||||
if c.opt.Protocol != 3 || c.pushProcessor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create handler context with client, connection pool, and connection information
|
||||
handlerCtx := c.pushNotificationHandlerContext(cn)
|
||||
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
|
||||
}
|
||||
|
||||
func (c *PubSub) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
|
||||
// PubSub doesn't have a client or connection pool, so we pass nil for those
|
||||
// PubSub connections are blocking
|
||||
return push.NotificationHandlerContext{
|
||||
PubSub: c,
|
||||
Conn: cn,
|
||||
IsBlocking: true,
|
||||
}
|
||||
}
|
||||
|
||||
type ChannelOption func(c *channel)
|
||||
|
||||
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
|
||||
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Push notification error definitions
|
||||
// This file contains all error types and messages used by the push notification system
|
||||
|
||||
// Error reason constants. They are stored in the Reason field of HandlerError
// and ProcessorError and are what the Is*Error helpers compare against.
const (
	// Handler reasons.
	ReasonHandlerNil       = "handler cannot be nil"
	ReasonHandlerExists    = "cannot overwrite existing handler"
	ReasonHandlerProtected = "handler is protected"

	// Processor reasons.
	ReasonPushNotificationsDisabled = "push notifications are disabled"
)

// ProcessorType represents the type of processor involved in the error.
// It is defined as a custom type for better readability and easier maintenance.
type ProcessorType string

// Known processor types.
const (
	ProcessorTypeProcessor     = ProcessorType("processor")
	ProcessorTypeVoidProcessor = ProcessorType("void_processor")
	ProcessorTypeCustom        = ProcessorType("custom")
)

// ProcessorOperation represents the operation being performed by the processor.
// It is defined as a custom type for better readability and easier maintenance.
type ProcessorOperation string

// Known processor operations.
const (
	ProcessorOperationProcess    = ProcessorOperation("process")
	ProcessorOperationRegister   = ProcessorOperation("register")
	ProcessorOperationUnregister = ProcessorOperation("unregister")
	ProcessorOperationUnknown    = ProcessorOperation("unknown")
)

// Common error variables for reuse.
var (
	// ErrHandlerNil is returned when attempting to register a nil handler.
	ErrHandlerNil = errors.New(ReasonHandlerNil)
)

// Registry errors

// ErrHandlerExists creates an error for when attempting to overwrite an existing handler.
// The result is a *HandlerError with Operation "register" and Reason ReasonHandlerExists.
func ErrHandlerExists(pushNotificationName string) error {
	return NewHandlerError(ProcessorOperationRegister, pushNotificationName, ReasonHandlerExists, nil)
}

// ErrProtectedHandler creates an error for when attempting to unregister a protected handler.
// The result is a *HandlerError with Operation "unregister" and Reason ReasonHandlerProtected.
func ErrProtectedHandler(pushNotificationName string) error {
	return NewHandlerError(ProcessorOperationUnregister, pushNotificationName, ReasonHandlerProtected, nil)
}

// VoidProcessor errors

// ErrVoidProcessorRegister creates an error for when attempting to register a handler on void processor.
func ErrVoidProcessorRegister(pushNotificationName string) error {
	return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationRegister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}

// ErrVoidProcessorUnregister creates an error for when attempting to unregister a handler on void processor.
func ErrVoidProcessorUnregister(pushNotificationName string) error {
	return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationUnregister, pushNotificationName, ReasonPushNotificationsDisabled, nil)
}

// Error type definitions for advanced error handling

// HandlerError represents errors related to handler operations.
type HandlerError struct {
	Operation            ProcessorOperation // operation that failed, e.g. "register"
	PushNotificationName string             // name of the push notification involved
	Reason               string             // one of the Reason* constants or free text
	Err                  error              // optional underlying cause; may be nil
}

// Error formats the operation, notification name and reason, appending the
// underlying cause in parentheses when one is set.
func (e *HandlerError) Error() string {
	if e.Err != nil {
		return fmt.Sprintf("handler %s failed for '%s': %s (%v)", e.Operation, e.PushNotificationName, e.Reason, e.Err)
	}
	return fmt.Sprintf("handler %s failed for '%s': %s", e.Operation, e.PushNotificationName, e.Reason)
}

// Unwrap returns the underlying cause so errors.Is / errors.As can see through
// the HandlerError.
func (e *HandlerError) Unwrap() error {
	return e.Err
}

// NewHandlerError creates a new HandlerError.
func NewHandlerError(operation ProcessorOperation, pushNotificationName, reason string, err error) *HandlerError {
	return &HandlerError{
		Operation:            operation,
		PushNotificationName: pushNotificationName,
		Reason:               reason,
		Err:                  err,
	}
}

// ProcessorError represents errors related to processor operations.
type ProcessorError struct {
	ProcessorType        ProcessorType      // one of the ProcessorType* constants
	Operation            ProcessorOperation // one of the ProcessorOperation* constants
	PushNotificationName string             // name of the push notification involved; may be empty
	Reason               string             // one of the Reason* constants or free text
	Err                  error              // optional underlying cause; may be nil
}

// Error formats the processor type, operation and reason. The notification
// name is included only when it is non-empty, and the underlying cause is
// appended in parentheses when one is set.
func (e *ProcessorError) Error() string {
	notifInfo := ""
	if e.PushNotificationName != "" {
		notifInfo = fmt.Sprintf(" for '%s'", e.PushNotificationName)
	}
	if e.Err != nil {
		return fmt.Sprintf("%s %s failed%s: %s (%v)", e.ProcessorType, e.Operation, notifInfo, e.Reason, e.Err)
	}
	return fmt.Sprintf("%s %s failed%s: %s", e.ProcessorType, e.Operation, notifInfo, e.Reason)
}

// Unwrap returns the underlying cause so errors.Is / errors.As can see through
// the ProcessorError.
func (e *ProcessorError) Unwrap() error {
	return e.Err
}

// NewProcessorError creates a new ProcessorError.
func NewProcessorError(processorType ProcessorType, operation ProcessorOperation, pushNotificationName, reason string, err error) *ProcessorError {
	return &ProcessorError{
		ProcessorType:        processorType,
		Operation:            operation,
		PushNotificationName: pushNotificationName,
		Reason:               reason,
		Err:                  err,
	}
}

// Helper functions for common error scenarios

// hasHandlerError reports whether any *HandlerError reachable from err satisfies
// match. errors.As yields only the first *HandlerError it meets, so the search
// continues through that error's wrapped cause until the chain is exhausted.
func hasHandlerError(err error, match func(*HandlerError) bool) bool {
	var handlerErr *HandlerError
	for errors.As(err, &handlerErr) {
		if handlerErr == nil {
			// A typed nil pointer was stored in the chain; there is nothing to inspect.
			return false
		}
		if match(handlerErr) {
			return true
		}
		err = handlerErr.Err
	}
	return false
}

// hasProcessorError is the *ProcessorError counterpart of hasHandlerError.
func hasProcessorError(err error, match func(*ProcessorError) bool) bool {
	var procErr *ProcessorError
	for errors.As(err, &procErr) {
		if procErr == nil {
			// A typed nil pointer was stored in the chain; there is nothing to inspect.
			return false
		}
		if match(procErr) {
			return true
		}
		err = procErr.Err
	}
	return false
}

// IsHandlerNilError checks if an error is due to a nil handler.
func IsHandlerNilError(err error) bool {
	return errors.Is(err, ErrHandlerNil)
}

// IsHandlerExistsError checks if an error is due to attempting to overwrite an existing handler.
// This function works correctly even when the error is wrapped, including when
// it is wrapped inside another HandlerError.
func IsHandlerExistsError(err error) bool {
	return hasHandlerError(err, func(e *HandlerError) bool {
		return e.Operation == ProcessorOperationRegister && e.Reason == ReasonHandlerExists
	})
}

// IsProtectedHandlerError checks if an error is due to attempting to unregister a protected handler.
// This function works correctly even when the error is wrapped, including when
// it is wrapped inside another HandlerError.
func IsProtectedHandlerError(err error) bool {
	return hasHandlerError(err, func(e *HandlerError) bool {
		return e.Operation == ProcessorOperationUnregister && e.Reason == ReasonHandlerProtected
	})
}

// IsVoidProcessorError checks if an error is due to void processor operations.
// This function works correctly even when the error is wrapped, including when
// it is wrapped inside another ProcessorError.
func IsVoidProcessorError(err error) bool {
	return hasProcessorError(err, func(e *ProcessorError) bool {
		return e.ProcessorType == ProcessorTypeVoidProcessor && e.Reason == ReasonPushNotificationsDisabled
	})
}
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// NotificationHandler defines the interface for push notification handlers.
|
||||
type NotificationHandler interface {
|
||||
// HandlePushNotification processes a push notification with context information.
|
||||
// The handlerCtx provides information about the client, connection pool, and connection
|
||||
// on which the notification was received, allowing handlers to make informed decisions.
|
||||
// Returns an error if the notification could not be handled.
|
||||
HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error
|
||||
}
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
package push
|
||||
|
||||
// No imports needed for this file
|
||||
|
||||
// NotificationHandlerContext provides context information about where a push notification was received.
|
||||
// This struct allows handlers to make informed decisions based on the source of the notification
|
||||
// with strongly typed access to different client types using concrete types.
|
||||
type NotificationHandlerContext struct {
|
||||
// Client is the Redis client instance that received the notification.
|
||||
// It is an interface both to allow for future expansion and to avoid
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *redis.baseClient
|
||||
// - *redis.Client
|
||||
// - *redis.ClusterClient
|
||||
// - *redis.Conn
|
||||
Client interface{}
|
||||
|
||||
// ConnPool is the connection pool from which the connection was obtained.
|
||||
// It is an interface both to allow for future expansion and to avoid
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *pool.ConnPool
|
||||
// - *pool.SingleConnPool
|
||||
// - *pool.StickyConnPool
|
||||
ConnPool interface{}
|
||||
|
||||
// PubSub is the PubSub instance that received the notification.
|
||||
// It is an interface both to allow for future expansion and to avoid
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *redis.PubSub
|
||||
PubSub interface{}
|
||||
|
||||
// Conn is the specific connection on which the notification was received.
|
||||
// It is an interface both to allow for future expansion and to avoid
|
||||
// circular dependencies. The developer is responsible for type assertion.
|
||||
// It can be one of the following types:
|
||||
// - *pool.Conn
|
||||
Conn interface{}
|
||||
|
||||
// IsBlocking indicates if the notification was received on a blocking connection.
|
||||
IsBlocking bool
|
||||
}
|
||||
+203
@@ -0,0 +1,203 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
// NotificationProcessor defines the interface for push notification processors.
|
||||
type NotificationProcessor interface {
|
||||
// GetHandler returns the handler for a specific push notification name.
|
||||
GetHandler(pushNotificationName string) NotificationHandler
|
||||
// ProcessPendingNotifications checks for and processes any pending push notifications.
|
||||
// To be used when it is known that there are notifications on the socket.
|
||||
// It will try to read from the socket and if it is empty - it may block.
|
||||
ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error
|
||||
// RegisterHandler registers a handler for a specific push notification name.
|
||||
RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error
|
||||
// UnregisterHandler removes a handler for a specific push notification name.
|
||||
UnregisterHandler(pushNotificationName string) error
|
||||
}
|
||||
|
||||
// Processor handles push notifications with a registry of handlers
|
||||
type Processor struct {
|
||||
registry *Registry
|
||||
}
|
||||
|
||||
// NewProcessor creates a new push notification processor
|
||||
func NewProcessor() *Processor {
|
||||
return &Processor{
|
||||
registry: NewRegistry(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for a specific push notification name
|
||||
func (p *Processor) GetHandler(pushNotificationName string) NotificationHandler {
|
||||
return p.registry.GetHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific push notification name
|
||||
func (p *Processor) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
|
||||
return p.registry.RegisterHandler(pushNotificationName, handler, protected)
|
||||
}
|
||||
|
||||
// UnregisterHandler removes a handler for a specific push notification name
|
||||
func (p *Processor) UnregisterHandler(pushNotificationName string) error {
|
||||
return p.registry.UnregisterHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// ProcessPendingNotifications checks for and processes any pending push notifications
|
||||
// This method should be called by the client in WithReader before reading the reply
|
||||
// It will try to read from the socket and if it is empty - it may block.
|
||||
func (p *Processor) ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
|
||||
if rd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
// Check if there's data available to read
|
||||
replyType, err := rd.PeekReplyType()
|
||||
if err != nil {
|
||||
// No more data available or error reading
|
||||
// if timeout, it will be handled by the caller
|
||||
break
|
||||
}
|
||||
|
||||
// Only process push notifications (arrays starting with >)
|
||||
if replyType != proto.RespPush {
|
||||
break
|
||||
}
|
||||
|
||||
// see if we should skip this notification
|
||||
notificationName, err := rd.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if willHandleNotificationInClient(notificationName) {
|
||||
break
|
||||
}
|
||||
|
||||
// Read the push notification
|
||||
reply, err := rd.ReadReply()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error reading push notification: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
// Convert to slice of interfaces
|
||||
notification, ok := reply.([]interface{})
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
// Handle the notification directly
|
||||
if len(notification) > 0 {
|
||||
// Extract the notification type (first element)
|
||||
if notificationType, ok := notification[0].(string); ok {
|
||||
// Get the handler for this notification type
|
||||
if handler := p.registry.GetHandler(notificationType); handler != nil {
|
||||
// Handle the notification
|
||||
err := handler.HandlePushNotification(ctx, handlerCtx, notification)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error handling push notification: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VoidProcessor is a processor without a handler registry: it throws push
// notifications away instead of dispatching them.
type VoidProcessor struct{}

// NewVoidProcessor returns a ready-to-use void push notification processor.
func NewVoidProcessor() *VoidProcessor {
	return new(VoidProcessor)
}
|
||||
|
||||
// GetHandler returns nil for void processor since it doesn't maintain handlers
|
||||
func (v *VoidProcessor) GetHandler(_ string) NotificationHandler {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterHandler returns an error for void processor since it doesn't maintain handlers
|
||||
func (v *VoidProcessor) RegisterHandler(pushNotificationName string, _ NotificationHandler, _ bool) error {
|
||||
return ErrVoidProcessorRegister(pushNotificationName)
|
||||
}
|
||||
|
||||
// UnregisterHandler returns an error for void processor since it doesn't maintain handlers
|
||||
func (v *VoidProcessor) UnregisterHandler(pushNotificationName string) error {
|
||||
return ErrVoidProcessorUnregister(pushNotificationName)
|
||||
}
|
||||
|
||||
// ProcessPendingNotifications for VoidProcessor does nothing since push notifications
|
||||
// are only available in RESP3 and this processor is used for RESP2 connections.
|
||||
// This avoids unnecessary buffer scanning overhead.
|
||||
// It does however read and discard all push notifications from the buffer to avoid
|
||||
// them being interpreted as a reply.
|
||||
// This method should be called by the client in WithReader before reading the reply
|
||||
// to be sure there are no buffered push notifications.
|
||||
// It will try to read from the socket and if it is empty - it may block.
|
||||
func (v *VoidProcessor) ProcessPendingNotifications(_ context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error {
|
||||
// read and discard all push notifications
|
||||
if rd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
// Check if there's data available to read
|
||||
replyType, err := rd.PeekReplyType()
|
||||
if err != nil {
|
||||
// No more data available or error reading
|
||||
// if timeout, it will be handled by the caller
|
||||
break
|
||||
}
|
||||
|
||||
// Only process push notifications (arrays starting with >)
|
||||
if replyType != proto.RespPush {
|
||||
break
|
||||
}
|
||||
// see if we should skip this notification
|
||||
notificationName, err := rd.PeekPushNotificationName()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if willHandleNotificationInClient(notificationName) {
|
||||
break
|
||||
}
|
||||
|
||||
// Read the push notification
|
||||
_, err = rd.ReadReply()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "push: error reading push notification: %v", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// willHandleNotificationInClient reports whether notificationType is a pub/sub
// notification that the push notification processor must leave alone so that
// the pub/sub machinery can consume it. Matching is exact and case-sensitive.
func willHandleNotificationInClient(notificationType string) bool {
	switch notificationType {
	case "message", "pmessage", "smessage":
		// Regular, pattern and sharded (Redis 7.0+) pub/sub messages.
		return true
	case "subscribe", "psubscribe", "ssubscribe":
		// Subscription confirmations.
		return true
	case "unsubscribe", "punsubscribe", "sunsubscribe":
		// Unsubscription confirmations.
		return true
	}
	return false
}
|
||||
+7
@@ -0,0 +1,7 @@
|
||||
// Package push provides push notifications for Redis.
|
||||
// This is an EXPERIMENTAL API for handling push notifications from Redis.
|
||||
// It is not yet stable and may change in the future.
|
||||
// Although this is in a public package, in its current form public use is not advised.
|
||||
// Pending push notifications should be processed before executing any readReply from the connection
|
||||
// as per RESP3 specification push notifications can be sent at any time.
|
||||
package push
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Registry manages push notification handlers
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
handlers map[string]NotificationHandler
|
||||
protected map[string]bool
|
||||
}
|
||||
|
||||
// NewRegistry creates a new push notification registry
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
handlers: make(map[string]NotificationHandler),
|
||||
protected: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific push notification name
|
||||
func (r *Registry) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error {
|
||||
if handler == nil {
|
||||
return ErrHandlerNil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Check if handler already exists
|
||||
if _, exists := r.protected[pushNotificationName]; exists {
|
||||
return ErrHandlerExists(pushNotificationName)
|
||||
}
|
||||
|
||||
r.handlers[pushNotificationName] = handler
|
||||
r.protected[pushNotificationName] = protected
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for a specific push notification name
|
||||
func (r *Registry) GetHandler(pushNotificationName string) NotificationHandler {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.handlers[pushNotificationName]
|
||||
}
|
||||
|
||||
// UnregisterHandler removes a handler for a specific push notification name
|
||||
func (r *Registry) UnregisterHandler(pushNotificationName string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Check if handler is protected
|
||||
if protected, exists := r.protected[pushNotificationName]; exists && protected {
|
||||
return ErrProtectedHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
delete(r.handlers, pushNotificationName)
|
||||
delete(r.protected, pushNotificationName)
|
||||
return nil
|
||||
}
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NewPushNotificationProcessor creates a new push notification processor
|
||||
// This processor maintains a registry of handlers and processes push notifications
|
||||
// It is used for RESP3 connections where push notifications are available
|
||||
func NewPushNotificationProcessor() push.NotificationProcessor {
|
||||
return push.NewProcessor()
|
||||
}
|
||||
|
||||
// NewVoidPushNotificationProcessor creates a new void push notification processor
|
||||
// This processor does not maintain any handlers and always returns nil for all operations
|
||||
// It is used for RESP2 connections where push notifications are not available
|
||||
// It can also be used to disable push notifications for RESP3 connections, where
|
||||
// it will discard all push notifications without processing them
|
||||
func NewVoidPushNotificationProcessor() push.NotificationProcessor {
|
||||
return push.NewVoidProcessor()
|
||||
}
|
||||
+557
-77
@@ -11,9 +11,12 @@ import (
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/auth/streaming"
|
||||
"github.com/redis/go-redis/v9/internal/hscan"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Scanner internal/hscan.Scanner exposed interface.
|
||||
@@ -23,10 +26,16 @@ type Scanner = hscan.Scanner
|
||||
const Nil = proto.Nil
|
||||
|
||||
// SetLogger set custom log
|
||||
// Use with VoidLogger to disable logging.
|
||||
func SetLogger(logger internal.Logging) {
|
||||
internal.Logger = logger
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the library.
|
||||
func SetLogLevel(logLevel internal.LogLevelT) {
|
||||
internal.LogLevel = logLevel
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type Hook interface {
|
||||
@@ -202,16 +211,39 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type baseClient struct {
|
||||
opt *Options
|
||||
connPool pool.Pooler
|
||||
opt *Options
|
||||
optLock sync.RWMutex
|
||||
connPool pool.Pooler
|
||||
pubSubPool *pool.PubSubPool
|
||||
hooksMixin
|
||||
|
||||
onClose func() error // hook called when client is closed
|
||||
|
||||
// Push notification processing
|
||||
pushProcessor push.NotificationProcessor
|
||||
|
||||
// Maintenance notifications manager
|
||||
maintNotificationsManager *maintnotifications.Manager
|
||||
maintNotificationsManagerLock sync.RWMutex
|
||||
|
||||
// streamingCredentialsManager is used to manage streaming credentials
|
||||
streamingCredentialsManager *streaming.Manager
|
||||
}
|
||||
|
||||
func (c *baseClient) clone() *baseClient {
|
||||
clone := *c
|
||||
return &clone
|
||||
c.maintNotificationsManagerLock.RLock()
|
||||
maintNotificationsManager := c.maintNotificationsManager
|
||||
c.maintNotificationsManagerLock.RUnlock()
|
||||
|
||||
clone := &baseClient{
|
||||
opt: c.opt,
|
||||
connPool: c.connPool,
|
||||
onClose: c.onClose,
|
||||
pushProcessor: c.pushProcessor,
|
||||
maintNotificationsManager: maintNotificationsManager,
|
||||
streamingCredentialsManager: c.streamingCredentialsManager,
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
|
||||
@@ -229,21 +261,6 @@ func (c *baseClient) String() string {
|
||||
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
|
||||
}
|
||||
|
||||
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
|
||||
cn, err := c.connPool.NewConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = c.connPool.CloseConn(cn)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
if c.opt.Limiter != nil {
|
||||
err := c.opt.Limiter.Allow()
|
||||
@@ -269,7 +286,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cn.Inited {
|
||||
if cn.IsInited() {
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
@@ -281,35 +298,39 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// initConn will transition to IDLE state, so we need to acquire it
|
||||
// before returning it to the user.
|
||||
if !cn.TryAcquire() {
|
||||
return nil, fmt.Errorf("redis: connection is not usable")
|
||||
}
|
||||
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener {
|
||||
return auth.NewReAuthCredentialsListener(
|
||||
c.reAuthConnection(poolCn),
|
||||
c.onAuthenticationErr(poolCn),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error {
|
||||
return func(credentials auth.Credentials) error {
|
||||
func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||
return func(poolCn *pool.Conn, credentials auth.Credentials) error {
|
||||
var err error
|
||||
username, password := credentials.BasicAuth()
|
||||
|
||||
// Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter
|
||||
ctx := context.Background()
|
||||
|
||||
connPool := pool.NewSingleConnPool(c.connPool, poolCn)
|
||||
// hooksMixin are intentionally empty here
|
||||
cn := newConn(c.opt, connPool, nil)
|
||||
|
||||
// Pass hooks so that reauth commands are recorded/traced
|
||||
cn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
if username != "" {
|
||||
err = cn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = cn.Auth(ctx, password).Err()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) {
|
||||
return func(err error) {
|
||||
func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) {
|
||||
return func(poolCn *pool.Conn, err error) {
|
||||
if err != nil {
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
// Close the connection to force a reconnection.
|
||||
@@ -351,29 +372,113 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error {
|
||||
}
|
||||
|
||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
if cn.Inited {
|
||||
// This function is called in two scenarios:
|
||||
// 1. First-time init: Connection is in CREATED state (from pool.Get())
|
||||
// - We need to transition CREATED → INITIALIZING and do the initialization
|
||||
// - If another goroutine is already initializing, we WAIT for it to finish
|
||||
// 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn())
|
||||
// - We're already in INITIALIZING, so just proceed with initialization
|
||||
|
||||
currentState := cn.GetStateMachine().GetState()
|
||||
|
||||
// Fast path: Check if already initialized (IDLE or IN_USE)
|
||||
if currentState == pool.StateIdle || currentState == pool.StateInUse {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
cn.Inited = true
|
||||
// If in CREATED state, try to transition to INITIALIZING
|
||||
if currentState == pool.StateCreated {
|
||||
finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing)
|
||||
if err != nil {
|
||||
// Another goroutine is initializing or connection is in unexpected state
|
||||
// Check what state we're in now
|
||||
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||
// Already initialized by another goroutine
|
||||
return nil
|
||||
}
|
||||
|
||||
if finalState == pool.StateInitializing {
|
||||
// Another goroutine is initializing - WAIT for it to complete
|
||||
// Use a context with timeout = min(remaining command timeout, DialTimeout)
|
||||
// This prevents waiting too long while respecting the caller's deadline
|
||||
var waitCtx context.Context
|
||||
var cancel context.CancelFunc
|
||||
dialTimeout := c.opt.DialTimeout
|
||||
|
||||
if cmdDeadline, hasCmdDeadline := ctx.Deadline(); hasCmdDeadline {
|
||||
// Calculate remaining time until command deadline
|
||||
remainingTime := time.Until(cmdDeadline)
|
||||
// Use the minimum of remaining time and DialTimeout
|
||||
if remainingTime < dialTimeout {
|
||||
// Command deadline is sooner, use it
|
||||
waitCtx = ctx
|
||||
} else {
|
||||
// DialTimeout is shorter, cap the wait at DialTimeout
|
||||
waitCtx, cancel = context.WithTimeout(ctx, dialTimeout)
|
||||
}
|
||||
} else {
|
||||
// No command deadline, use DialTimeout to prevent waiting indefinitely
|
||||
waitCtx, cancel = context.WithTimeout(ctx, dialTimeout)
|
||||
}
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
finalState, err := cn.GetStateMachine().AwaitAndTransition(
|
||||
waitCtx,
|
||||
[]pool.ConnState{pool.StateIdle, pool.StateInUse},
|
||||
pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op)
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Verify we're now initialized
|
||||
if finalState == pool.StateIdle || finalState == pool.StateInUse {
|
||||
return nil
|
||||
}
|
||||
// Unexpected state after waiting
|
||||
return fmt.Errorf("connection in unexpected state after initialization: %s", finalState)
|
||||
}
|
||||
|
||||
// Unexpected state (CLOSED, UNUSABLE, etc.)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, we're in INITIALIZING state and we own the initialization
|
||||
// If we fail, we must transition to CLOSED
|
||||
var initErr error
|
||||
connPool := pool.NewSingleConnPool(c.connPool, cn)
|
||||
conn := newConn(c.opt, connPool, &c.hooksMixin)
|
||||
|
||||
username, password := "", ""
|
||||
if c.opt.StreamingCredentialsProvider != nil {
|
||||
credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(c.newReAuthCredentialsListener(cn))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", err)
|
||||
credListener, initErr := c.streamingCredentialsManager.Listener(
|
||||
cn,
|
||||
c.reAuthConnection(),
|
||||
c.onAuthenticationErr(),
|
||||
)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to create credentials listener: %w", initErr)
|
||||
}
|
||||
|
||||
credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider.
|
||||
Subscribe(credListener)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr)
|
||||
}
|
||||
|
||||
c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider)
|
||||
cn.SetOnClose(unsubscribeFromCredentialsProvider)
|
||||
|
||||
username, password = credentials.BasicAuth()
|
||||
} else if c.opt.CredentialsProviderContext != nil {
|
||||
username, password, err = c.opt.CredentialsProviderContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", err)
|
||||
username, password, initErr = c.opt.CredentialsProviderContext(ctx)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to get credentials from context provider: %w", initErr)
|
||||
}
|
||||
} else if c.opt.CredentialsProvider != nil {
|
||||
username, password = c.opt.CredentialsProvider()
|
||||
@@ -383,9 +488,9 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
// for redis-server versions that do not support the HELLO command,
|
||||
// RESP2 will continue to be used.
|
||||
if err = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); err == nil {
|
||||
if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil {
|
||||
// Authentication successful with HELLO command
|
||||
} else if !isRedisError(err) {
|
||||
} else if !isRedisError(initErr) {
|
||||
// When the server responds with the RESP protocol and the result is not a normal
|
||||
// execution result of the HELLO command, we consider it to be an indication that
|
||||
// the server does not support the HELLO command.
|
||||
@@ -393,20 +498,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
// or it could be DragonflyDB or a third-party redis-proxy. They all respond
|
||||
// with different error string results for unsupported commands, making it
|
||||
// difficult to rely on error strings to determine all results.
|
||||
return err
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
} else if password != "" {
|
||||
// Try legacy AUTH command if HELLO failed
|
||||
if username != "" {
|
||||
err = conn.AuthACL(ctx, username, password).Err()
|
||||
initErr = conn.AuthACL(ctx, username, password).Err()
|
||||
} else {
|
||||
err = conn.Auth(ctx, password).Err()
|
||||
initErr = conn.Auth(ctx, password).Err()
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to authenticate: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to authenticate: %w", initErr)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
_, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error {
|
||||
if c.opt.DB > 0 {
|
||||
pipe.Select(ctx, c.opt.DB)
|
||||
}
|
||||
@@ -421,8 +528,58 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize connection options: %w", err)
|
||||
if initErr != nil {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to initialize connection options: %w", initErr)
|
||||
}
|
||||
|
||||
// Enable maintnotifications if maintnotifications are configured
|
||||
c.optLock.RLock()
|
||||
maintNotifEnabled := c.opt.MaintNotificationsConfig != nil && c.opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled
|
||||
protocol := c.opt.Protocol
|
||||
endpointType := c.opt.MaintNotificationsConfig.EndpointType
|
||||
c.optLock.RUnlock()
|
||||
var maintNotifHandshakeErr error
|
||||
if maintNotifEnabled && protocol == 3 {
|
||||
maintNotifHandshakeErr = conn.ClientMaintNotifications(
|
||||
ctx,
|
||||
true,
|
||||
endpointType.String(),
|
||||
).Err()
|
||||
if maintNotifHandshakeErr != nil {
|
||||
if !isRedisError(maintNotifHandshakeErr) {
|
||||
// if not redis error, fail the connection
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return maintNotifHandshakeErr
|
||||
}
|
||||
c.optLock.Lock()
|
||||
// handshake failed - check and modify config atomically
|
||||
switch c.opt.MaintNotificationsConfig.Mode {
|
||||
case maintnotifications.ModeEnabled:
|
||||
// enabled mode, fail the connection
|
||||
c.optLock.Unlock()
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr)
|
||||
default: // will handle auto and any other
|
||||
// Disabling logging here as it's too noisy.
|
||||
// TODO: Enable when we have a better logging solution for log levels
|
||||
// internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr)
|
||||
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled
|
||||
c.optLock.Unlock()
|
||||
// auto mode, disable maintnotifications and continue
|
||||
if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil {
|
||||
// Log error but continue - auto mode should be resilient
|
||||
internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// handshake was executed successfully
|
||||
// to make sure that the handshake will be executed on other connections as well if it was successfully
|
||||
// executed on this connection, we will force the handshake to be executed on all connections
|
||||
c.optLock.Lock()
|
||||
c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeEnabled
|
||||
c.optLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if !c.opt.DisableIdentity && !c.opt.DisableIndentity {
|
||||
@@ -436,13 +593,31 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
|
||||
p.ClientSetInfo(ctx, WithLibraryVersion(libVer))
|
||||
// Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid
|
||||
// out of order responses later on.
|
||||
if _, err = p.Exec(ctx); err != nil && !isRedisError(err) {
|
||||
return err
|
||||
if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) {
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
}
|
||||
}
|
||||
|
||||
// Set the connection initialization function for potential reconnections
|
||||
// This must be set before transitioning to IDLE so that handoff/reauth can use it
|
||||
cn.SetInitConnFunc(c.createInitConnFunc())
|
||||
|
||||
// Initialization succeeded - transition to IDLE state
|
||||
// This marks the connection as initialized and ready for use
|
||||
// NOTE: The connection is still owned by the calling goroutine at this point
|
||||
// and won't be available to other goroutines until it's Put() back into the pool
|
||||
cn.GetStateMachine().Transition(pool.StateIdle)
|
||||
|
||||
// Call OnConnect hook if configured
|
||||
// The connection is in IDLE state but still owned by this goroutine
|
||||
// If OnConnect needs to send commands, it can use the connection safely
|
||||
if c.opt.OnConnect != nil {
|
||||
return c.opt.OnConnect(ctx, conn)
|
||||
if initErr = c.opt.OnConnect(ctx, conn); initErr != nil {
|
||||
// OnConnect failed - transition to closed
|
||||
cn.GetStateMachine().Transition(pool.StateClosed)
|
||||
return initErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -456,6 +631,10 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
|
||||
if isBadConn(err, false, c.opt.Addr) {
|
||||
c.connPool.Remove(ctx, cn, err)
|
||||
} else {
|
||||
// process any pending push notifications before returning the connection to the pool
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err)
|
||||
}
|
||||
c.connPool.Put(ctx, cn)
|
||||
}
|
||||
}
|
||||
@@ -497,16 +676,16 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (c *baseClient) assertUnstableCommand(cmd Cmder) bool {
|
||||
func (c *baseClient) assertUnstableCommand(cmd Cmder) (bool, error) {
|
||||
switch cmd.(type) {
|
||||
case *AggregateCmd, *FTInfoCmd, *FTSpellCheckCmd, *FTSearchCmd, *FTSynDumpCmd:
|
||||
if c.opt.UnstableResp3 {
|
||||
return true
|
||||
return true, nil
|
||||
} else {
|
||||
panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.")
|
||||
return false, fmt.Errorf("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3. See the README and the release notes for guidance")
|
||||
}
|
||||
default:
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -519,6 +698,11 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
|
||||
retryTimeout := uint32(0)
|
||||
if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Process any pending push notifications before executing the command
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmd(wr, cmd)
|
||||
}); err != nil {
|
||||
@@ -527,10 +711,22 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool
|
||||
}
|
||||
readReplyFunc := cmd.readReply
|
||||
// Apply unstable RESP3 search module.
|
||||
if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) {
|
||||
readReplyFunc = cmd.readRawReply
|
||||
if c.opt.Protocol != 2 {
|
||||
useRawReply, err := c.assertUnstableCommand(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if useRawReply {
|
||||
readReplyFunc = cmd.readRawReply
|
||||
}
|
||||
}
|
||||
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil {
|
||||
if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
return readReplyFunc(rd)
|
||||
}); err != nil {
|
||||
if cmd.readTimeout() == nil {
|
||||
atomic.StoreUint32(&retryTimeout, 1)
|
||||
} else {
|
||||
@@ -573,19 +769,76 @@ func (c *baseClient) context(ctx context.Context) context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// createInitConnFunc creates a connection initialization function that can be used for reconnections.
|
||||
func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error {
|
||||
return func(ctx context.Context, cn *pool.Conn) error {
|
||||
return c.initConn(ctx, cn)
|
||||
}
|
||||
}
|
||||
|
||||
// enableMaintNotificationsUpgrades initializes the maintnotifications upgrade manager and pool hook.
|
||||
// This function is called during client initialization.
|
||||
// will register push notification handlers for all maintenance upgrade events.
|
||||
// will start background workers for handoff processing in the pool hook.
|
||||
func (c *baseClient) enableMaintNotificationsUpgrades() error {
|
||||
// Create client adapter
|
||||
clientAdapterInstance := newClientAdapter(c)
|
||||
|
||||
// Create maintnotifications manager directly
|
||||
manager, err := maintnotifications.NewManager(clientAdapterInstance, c.connPool, c.opt.MaintNotificationsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Set the manager reference and initialize pool hook
|
||||
c.maintNotificationsManagerLock.Lock()
|
||||
c.maintNotificationsManager = manager
|
||||
c.maintNotificationsManagerLock.Unlock()
|
||||
|
||||
// Initialize pool hook (safe to call without lock since manager is now set)
|
||||
manager.InitPoolHook(c.dialHook)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseClient) disableMaintNotificationsUpgrades() error {
|
||||
c.maintNotificationsManagerLock.Lock()
|
||||
defer c.maintNotificationsManagerLock.Unlock()
|
||||
|
||||
// Close the maintnotifications manager
|
||||
if c.maintNotificationsManager != nil {
|
||||
// Closing the manager will also shutdown the pool hook
|
||||
// and remove it from the pool
|
||||
c.maintNotificationsManager.Close()
|
||||
c.maintNotificationsManager = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the client, releasing any open resources.
|
||||
//
|
||||
// It is rare to Close a Client, as the Client is meant to be
|
||||
// long-lived and shared between many goroutines.
|
||||
func (c *baseClient) Close() error {
|
||||
var firstErr error
|
||||
|
||||
// Close maintnotifications manager first
|
||||
if err := c.disableMaintNotificationsUpgrades(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
|
||||
if c.onClose != nil {
|
||||
if err := c.onClose(); err != nil {
|
||||
if err := c.onClose(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if err := c.connPool.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
if c.connPool != nil {
|
||||
if err := c.connPool.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if c.pubSubPool != nil {
|
||||
if err := c.pubSubPool.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
@@ -625,12 +878,19 @@ func (c *baseClient) generalProcessPipeline(
|
||||
// Enable retries by default to retry dial errors returned by withConn.
|
||||
canRetry := true
|
||||
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
|
||||
// Process any pending push notifications before executing the pipeline
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err)
|
||||
}
|
||||
var err error
|
||||
canRetry, err = p(ctx, cn, cmds)
|
||||
return err
|
||||
})
|
||||
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
|
||||
setCmdsErr(cmds, lastErr)
|
||||
// The error should be set here only when failing to obtain the conn.
|
||||
if !isRedisError(lastErr) {
|
||||
setCmdsErr(cmds, lastErr)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
}
|
||||
@@ -640,6 +900,11 @@ func (c *baseClient) generalProcessPipeline(
|
||||
func (c *baseClient) pipelineProcessCmds(
|
||||
ctx context.Context, cn *pool.Conn, cmds []Cmder,
|
||||
) (bool, error) {
|
||||
// Process any pending push notifications before executing the pipeline
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
}); err != nil {
|
||||
@@ -648,7 +913,8 @@ func (c *baseClient) pipelineProcessCmds(
|
||||
}
|
||||
|
||||
if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
|
||||
return pipelineReadCmds(rd, cmds)
|
||||
// read all replies
|
||||
return c.pipelineReadCmds(ctx, cn, rd, cmds)
|
||||
}); err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -656,8 +922,12 @@ func (c *baseClient) pipelineProcessCmds(
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
|
||||
func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *proto.Reader, cmds []Cmder) error {
|
||||
for i, cmd := range cmds {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
err := cmd.readReply(rd)
|
||||
cmd.SetErr(err)
|
||||
if err != nil && !isRedisError(err) {
|
||||
@@ -672,6 +942,11 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
|
||||
func (c *baseClient) txPipelineProcessCmds(
|
||||
ctx context.Context, cn *pool.Conn, cmds []Cmder,
|
||||
) (bool, error) {
|
||||
// Process any pending push notifications before executing the transaction pipeline
|
||||
if err := c.processPushNotifications(ctx, cn); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err)
|
||||
}
|
||||
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
}); err != nil {
|
||||
@@ -684,12 +959,13 @@ func (c *baseClient) txPipelineProcessCmds(
|
||||
// Trim multi and exec.
|
||||
trimmedCmds := cmds[1 : len(cmds)-1]
|
||||
|
||||
if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil {
|
||||
if err := c.txPipelineReadQueued(ctx, cn, rd, statusCmd, trimmedCmds); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return pipelineReadCmds(rd, trimmedCmds)
|
||||
// Read replies.
|
||||
return c.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -697,7 +973,13 @@ func (c *baseClient) txPipelineProcessCmds(
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
|
||||
// txPipelineReadQueued reads queued replies from the Redis server.
|
||||
// It returns an error if the server returns an error or if the number of replies does not match the number of commands.
|
||||
func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
// Parse +OK.
|
||||
if err := statusCmd.readReply(rd); err != nil {
|
||||
return err
|
||||
@@ -705,6 +987,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
||||
|
||||
// Parse +QUEUED.
|
||||
for _, cmd := range cmds {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
if err := statusCmd.readReply(rd); err != nil {
|
||||
cmd.SetErr(err)
|
||||
if !isRedisError(err) {
|
||||
@@ -713,6 +999,10 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
|
||||
}
|
||||
}
|
||||
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
// Parse number of replies.
|
||||
line, err := rd.ReadLine()
|
||||
if err != nil {
|
||||
@@ -746,15 +1036,61 @@ func NewClient(opt *Options) *Client {
|
||||
if opt == nil {
|
||||
panic("redis: NewClient nil options")
|
||||
}
|
||||
// clone to not share options with the caller
|
||||
opt = opt.clone()
|
||||
opt.init()
|
||||
|
||||
// Push notifications are always enabled for RESP3 (cannot be disabled)
|
||||
|
||||
c := Client{
|
||||
baseClient: &baseClient{
|
||||
opt: opt,
|
||||
},
|
||||
}
|
||||
c.init()
|
||||
c.connPool = newConnPool(opt, c.dialHook)
|
||||
|
||||
// Initialize push notification processor using shared helper
|
||||
// Use void processor for RESP2 connections (push notifications not available)
|
||||
c.pushProcessor = initializePushProcessor(opt)
|
||||
// set opt push processor for child clients
|
||||
c.opt.PushNotificationProcessor = c.pushProcessor
|
||||
|
||||
// Create connection pools
|
||||
var err error
|
||||
c.connPool, err = newConnPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
|
||||
}
|
||||
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
|
||||
}
|
||||
|
||||
if opt.StreamingCredentialsProvider != nil {
|
||||
c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout)
|
||||
c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook())
|
||||
}
|
||||
|
||||
// Initialize maintnotifications first if enabled and protocol is RESP3
|
||||
if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 {
|
||||
err := c.enableMaintNotificationsUpgrades()
|
||||
if err != nil {
|
||||
internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err)
|
||||
if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled {
|
||||
/*
|
||||
Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested.
|
||||
We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect
|
||||
an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced
|
||||
immediately, rather than allowing the client to continue in a partially initialized or inconsistent state.
|
||||
Clients relying on maintnotifications should be aware that initialization errors will cause a panic, and should
|
||||
handle this accordingly (e.g., via recover or by validating configuration before calling NewClient).
|
||||
This approach is only used when MaintNotificationsConfig.Mode is maintnotifications.ModeEnabled, indicating that maintnotifications
|
||||
upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic.
|
||||
*/
|
||||
panic(fmt.Errorf("failed to enable maintnotifications: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &c
|
||||
}
|
||||
@@ -791,11 +1127,51 @@ func (c *Client) Options() *Options {
|
||||
return c.opt
|
||||
}
|
||||
|
||||
// GetMaintNotificationsManager returns the maintnotifications manager instance for monitoring and control.
|
||||
// Returns nil if maintnotifications are not enabled.
|
||||
func (c *Client) GetMaintNotificationsManager() *maintnotifications.Manager {
|
||||
c.maintNotificationsManagerLock.RLock()
|
||||
defer c.maintNotificationsManagerLock.RUnlock()
|
||||
return c.maintNotificationsManager
|
||||
}
|
||||
|
||||
// initializePushProcessor initializes the push notification processor for any client type.
|
||||
// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient.
|
||||
func initializePushProcessor(opt *Options) push.NotificationProcessor {
|
||||
// Always use custom processor if provided
|
||||
if opt.PushNotificationProcessor != nil {
|
||||
return opt.PushNotificationProcessor
|
||||
}
|
||||
|
||||
// Push notifications are always enabled for RESP3, disabled for RESP2
|
||||
if opt.Protocol == 3 {
|
||||
// Create default processor for RESP3 connections
|
||||
return NewPushNotificationProcessor()
|
||||
}
|
||||
|
||||
// Create void processor for RESP2 connections (push notifications not available)
|
||||
return NewVoidPushNotificationProcessor()
|
||||
}
|
||||
|
||||
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
|
||||
// Returns an error if a handler is already registered for this push notification name.
|
||||
// If protected is true, the handler cannot be unregistered.
|
||||
func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
|
||||
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
|
||||
}
|
||||
|
||||
// GetPushNotificationHandler returns the handler for a specific push notification name.
|
||||
// Returns nil if no handler is registered for the given name.
|
||||
func (c *Client) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
|
||||
return c.pushProcessor.GetHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
type PoolStats pool.Stats
|
||||
|
||||
// PoolStats returns connection pool stats.
|
||||
func (c *Client) PoolStats() *PoolStats {
|
||||
stats := c.connPool.Stats()
|
||||
stats.PubSubStats = *(c.pubSubPool.Stats())
|
||||
return (*PoolStats)(stats)
|
||||
}
|
||||
|
||||
@@ -830,13 +1206,31 @@ func (c *Client) TxPipeline() Pipeliner {
|
||||
func (c *Client) pubSub() *PubSub {
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt,
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
return c.newConn(ctx)
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// will return nil if already initialized
|
||||
err = c.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Track connection in PubSubPool
|
||||
c.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: c.connPool.CloseConn,
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
// Untrack connection from PubSubPool
|
||||
c.pubSubPool.UntrackConn(cn)
|
||||
_ = cn.Close()
|
||||
return nil
|
||||
},
|
||||
pushProcessor: c.pushProcessor,
|
||||
}
|
||||
pubsub.init()
|
||||
|
||||
return pubsub
|
||||
}
|
||||
|
||||
@@ -920,6 +1314,10 @@ func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn
|
||||
c.hooksMixin = parentHooks.clone()
|
||||
}
|
||||
|
||||
// Initialize push notification processor using shared helper
|
||||
// Use void processor for RESP2 connections (push notifications not available)
|
||||
c.pushProcessor = initializePushProcessor(opt)
|
||||
|
||||
c.cmdable = c.Process
|
||||
c.statefulCmdable = c.Process
|
||||
c.initHooks(hooks{
|
||||
@@ -938,6 +1336,13 @@ func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
|
||||
// Returns an error if a handler is already registered for this push notification name.
|
||||
// If protected is true, the handler cannot be unregistered.
|
||||
func (c *Conn) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
|
||||
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
|
||||
}
|
||||
|
||||
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
|
||||
return c.Pipeline().Pipelined(ctx, fn)
|
||||
}
|
||||
@@ -965,3 +1370,78 @@ func (c *Conn) TxPipeline() Pipeliner {
|
||||
pipe.init()
|
||||
return &pipe
|
||||
}
|
||||
|
||||
// processPushNotifications processes all pending push notifications on a connection
|
||||
// This ensures that cluster topology changes are handled immediately before the connection is used
|
||||
// This method should be called by the client before using WithReader for command execution
|
||||
//
|
||||
// Performance optimization: Skip the expensive MaybeHasData() syscall if a health check
|
||||
// was performed recently (within 5 seconds). The health check already verified the connection
|
||||
// is healthy and checked for unexpected data (push notifications).
|
||||
func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn) error {
|
||||
// Only process push notifications for RESP3 connections with a processor
|
||||
if c.opt.Protocol != 3 || c.pushProcessor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performance optimization: Skip MaybeHasData() syscall if health check was recent
|
||||
// If the connection was health-checked within the last 5 seconds, we can skip the
|
||||
// expensive syscall since the health check already verified no unexpected data.
|
||||
// This is safe because:
|
||||
// 0. lastHealthCheckNs is set in pool/conn.go:putConn() after a successful health check
|
||||
// 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK)
|
||||
// 2. If push notifications arrived, they would have been detected by health check
|
||||
// 3. 5 seconds is short enough that connection state is still fresh
|
||||
// 4. Push notifications will be processed by the next WithReader call
|
||||
// used it is set on getConn, so we should use another timer (lastPutAt?)
|
||||
lastHealthCheckNs := cn.LastPutAtNs()
|
||||
if lastHealthCheckNs > 0 {
|
||||
// Use pool's cached time to avoid expensive time.Now() syscall
|
||||
nowNs := pool.GetCachedTimeNs()
|
||||
if nowNs-lastHealthCheckNs < int64(5*time.Second) {
|
||||
// Recent health check confirmed no unexpected data, skip the syscall
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there is any data to read before processing
|
||||
// This is an optimization on UNIX systems where MaybeHasData is a syscall
|
||||
// On Windows, MaybeHasData always returns true, so this check is a no-op
|
||||
if !cn.MaybeHasData() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use WithReader to access the reader and process push notifications
|
||||
// This is critical for maintnotifications to work properly
|
||||
// NOTE: almost no timeouts are set for this read, so it should not block
|
||||
// longer than necessary, 10us should be plenty of time to read if there are any push notifications
|
||||
// on the socket.
|
||||
return cn.WithReader(ctx, 10*time.Microsecond, func(rd *proto.Reader) error {
|
||||
// Create handler context with client, connection pool, and connection information
|
||||
handlerCtx := c.pushNotificationHandlerContext(cn)
|
||||
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
|
||||
})
|
||||
}
|
||||
|
||||
// processPendingPushNotificationWithReader processes all pending push notifications on a connection
|
||||
// This method should be called by the client in WithReader before reading the reply
|
||||
func (c *baseClient) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error {
|
||||
// if we have the reader, we don't need to check for data on the socket, we are waiting
|
||||
// for either a reply or a push notification, so we can block until we get a reply or reach the timeout
|
||||
if c.opt.Protocol != 3 || c.pushProcessor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create handler context with client, connection pool, and connection information
|
||||
handlerCtx := c.pushNotificationHandlerContext(cn)
|
||||
return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd)
|
||||
}
|
||||
|
||||
// pushNotificationHandlerContext creates a handler context for push notification processing
|
||||
func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext {
|
||||
return push.NotificationHandlerContext{
|
||||
Client: c,
|
||||
ConnPool: c.connPool,
|
||||
Conn: cn, // Wrap in adapter for easier interface access
|
||||
}
|
||||
}
|
||||
|
||||
+541
-6
@@ -29,6 +29,8 @@ type SearchCmdable interface {
|
||||
FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd
|
||||
FTExplain(ctx context.Context, index string, query string) *StringCmd
|
||||
FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd
|
||||
FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd
|
||||
FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd
|
||||
FTInfo(ctx context.Context, index string) *FTInfoCmd
|
||||
FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd
|
||||
FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd
|
||||
@@ -344,6 +346,92 @@ type FTSearchOptions struct {
|
||||
DialectVersion int
|
||||
}
|
||||
|
||||
// FTHybridCombineMethod names the fusion method used to combine the search
// and vector result sets of FT.HYBRID. Its string value is sent verbatim
// after the COMBINE keyword.
type FTHybridCombineMethod string

// Fusion methods accepted in FTHybridCombineOptions.Method.
const (
	// FTHybridCombineRRF selects "RRF"; WINDOW and CONSTANT apply to it.
	FTHybridCombineRRF FTHybridCombineMethod = "RRF"
	// FTHybridCombineLinear selects "LINEAR"; ALPHA and BETA apply to it.
	FTHybridCombineLinear FTHybridCombineMethod = "LINEAR"
	// FTHybridCombineFunction selects "FUNCTION". FTHybridWithArgs emits no
	// method-specific parameters for it.
	FTHybridCombineFunction FTHybridCombineMethod = "FUNCTION"
)
|
||||
|
||||
// FTHybridSearchExpression describes one SEARCH clause of an FT.HYBRID call.
type FTHybridSearchExpression struct {
	// Query is the text query sent right after SEARCH.
	Query string
	// Scorer, when non-empty, is sent as SCORER <Scorer>.
	Scorer string
	// ScorerParams are appended after the scorer name; they are ignored
	// when Scorer is empty.
	ScorerParams []interface{}
	// YieldScoreAs, when non-empty, is sent as YIELD_SCORE_AS <name>.
	YieldScoreAs string
}
|
||||
|
||||
// FTHybridVectorMethod names the vector query method of an FT.HYBRID VSIM
// clause. It is an alias for string, so plain string values are accepted
// wherever it is used.
type FTHybridVectorMethod = string

// Vector query methods for FTHybridVectorExpression.Method.
//
// These were previously typed as FTHybridCombineMethod, which made them
// unusable in the Method field they exist for (a distinct named string type
// is not assignable to string without a conversion). They now carry the
// vector-method type; explicit string(KNN) conversions keep compiling.
const (
	KNN   FTHybridVectorMethod = "KNN"
	RANGE FTHybridVectorMethod = "RANGE"
)
|
||||
|
||||
// FTHybridVectorExpression represents a vector expression in hybrid search
|
||||
type FTHybridVectorExpression struct {
|
||||
VectorField string
|
||||
VectorData Vector
|
||||
Method FTHybridVectorMethod
|
||||
MethodParams []interface{}
|
||||
Filter string
|
||||
YieldScoreAs string
|
||||
}
|
||||
|
||||
// FTHybridCombineOptions represents options for result fusion
|
||||
type FTHybridCombineOptions struct {
|
||||
Method FTHybridCombineMethod
|
||||
Count int
|
||||
Window int // For RRF
|
||||
Constant float64 // For RRF
|
||||
Alpha float64 // For LINEAR
|
||||
Beta float64 // For LINEAR
|
||||
YieldScoreAs string
|
||||
}
|
||||
|
||||
// FTHybridGroupBy configures the GROUPBY clause of FT.HYBRID.
type FTHybridGroupBy struct {
	// Count is sent right after GROUPBY, before the fields.
	Count int
	// Fields are the grouping fields, sent in order.
	Fields []string
	// ReduceFunc, when non-empty, adds REDUCE <ReduceFunc> <ReduceCount>
	// followed by ReduceParams.
	ReduceFunc string
	// ReduceCount is the count sent after the reducer name.
	ReduceCount int
	// ReduceParams are the reducer arguments.
	ReduceParams []interface{}
}
|
||||
|
||||
// FTHybridApply describes one APPLY transformation; it is sent as
// APPLY <Expression> AS <AsField>.
type FTHybridApply struct {
	Expression, AsField string
}
|
||||
|
||||
// FTHybridWithCursor holds the optional WITHCURSOR settings for FT.HYBRID.
// Each value is only sent when it is greater than zero.
type FTHybridWithCursor struct {
	// Count is the number of results to return per cursor read.
	Count int
	// MaxIdle is the maximum idle time, in milliseconds, before the cursor
	// is automatically deleted.
	MaxIdle int
}
|
||||
|
||||
// FTHybridOptions hold options that can be passed to the FT.HYBRID command
|
||||
type FTHybridOptions struct {
|
||||
CountExpressions int // Number of search/vector expressions
|
||||
SearchExpressions []FTHybridSearchExpression // Multiple search expressions
|
||||
VectorExpressions []FTHybridVectorExpression // Multiple vector expressions
|
||||
Combine *FTHybridCombineOptions // Fusion step options
|
||||
Load []string // Projected fields
|
||||
GroupBy *FTHybridGroupBy // Aggregation grouping
|
||||
Apply []FTHybridApply // Field transformations
|
||||
SortBy []FTSearchSortBy // Reuse from FTSearch
|
||||
Filter string // Post-filter expression
|
||||
LimitOffset int // Result limiting
|
||||
Limit int
|
||||
Params map[string]interface{} // Parameter substitution
|
||||
ExplainScore bool // Include score explanations
|
||||
Timeout int // Runtime timeout
|
||||
WithCursor bool // Enable cursor support for large result sets
|
||||
WithCursorOptions *FTHybridWithCursor // Cursor configuration options
|
||||
}
|
||||
|
||||
type FTSynDumpResult struct {
|
||||
Term string
|
||||
Synonyms []string
|
||||
@@ -423,6 +511,14 @@ type FTAttribute struct {
|
||||
PhoneticMatcher string
|
||||
CaseSensitive bool
|
||||
WithSuffixtrie bool
|
||||
|
||||
// Vector specific attributes
|
||||
Algorithm string
|
||||
DataType string
|
||||
Dim int
|
||||
DistanceMetric string
|
||||
M int
|
||||
EFConstruction int
|
||||
}
|
||||
|
||||
type CursorStats struct {
|
||||
@@ -1296,21 +1392,26 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
|
||||
for _, attr := range attributes {
|
||||
if attrMap, ok := attr.([]interface{}); ok {
|
||||
att := FTAttribute{}
|
||||
for i := 0; i < len(attrMap); i++ {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "attribute" {
|
||||
attrLen := len(attrMap)
|
||||
for i := 0; i < attrLen; i++ {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "attribute" && i+1 < attrLen {
|
||||
att.Attribute = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "identifier" {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "identifier" && i+1 < attrLen {
|
||||
att.Identifier = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "type" {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "type" && i+1 < attrLen {
|
||||
att.Type = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "weight" {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "weight" && i+1 < attrLen {
|
||||
att.Weight = internal.ToFloat(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "nostem" {
|
||||
@@ -1329,7 +1430,7 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
|
||||
att.UNF = true
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "phonetic" {
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "phonetic" && i+1 < attrLen {
|
||||
att.PhoneticMatcher = internal.ToString(attrMap[i+1])
|
||||
continue
|
||||
}
|
||||
@@ -1342,6 +1443,38 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// vector specific attributes
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "algorithm" && i+1 < attrLen {
|
||||
att.Algorithm = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "data_type" && i+1 < attrLen {
|
||||
att.DataType = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "dim" && i+1 < attrLen {
|
||||
att.Dim = internal.ToInteger(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "distance_metric" && i+1 < attrLen {
|
||||
att.DistanceMetric = internal.ToString(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "m" && i+1 < attrLen {
|
||||
att.M = internal.ToInteger(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if internal.ToLower(internal.ToString(attrMap[i])) == "ef_construction" && i+1 < attrLen {
|
||||
att.EFConstruction = internal.ToInteger(attrMap[i+1])
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
ftInfo.Attributes = append(ftInfo.Attributes, att)
|
||||
}
|
||||
@@ -1819,6 +1952,207 @@ func (cmd *FTSearchCmd) readReply(rd *proto.Reader) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FTHybridResult is the parsed (non-cursor) reply of an FT.HYBRID call.
type FTHybridResult struct {
	// TotalResults comes from the reply's "total_results" entry.
	TotalResults int
	// Results holds one field map per entry of the reply's "results" list.
	Results []map[string]interface{}
	// Warnings holds the string entries of the optional "warnings" list.
	Warnings []string
	// ExecutionTime comes from the optional "execution_time" entry.
	ExecutionTime float64
}
|
||||
|
||||
// FTHybridCursorResult is the parsed reply of an FT.HYBRID call issued with
// WITHCURSOR: the cursor IDs found under the reply's "SEARCH" and "VSIM"
// keys.
type FTHybridCursorResult struct {
	SearchCursorID, VsimCursorID int
}
|
||||
|
||||
type FTHybridCmd struct {
|
||||
baseCmd
|
||||
val FTHybridResult
|
||||
cursorVal *FTHybridCursorResult
|
||||
options *FTHybridOptions
|
||||
withCursor bool
|
||||
}
|
||||
|
||||
func newFTHybridCmd(ctx context.Context, options *FTHybridOptions, args ...interface{}) *FTHybridCmd {
|
||||
var withCursor bool
|
||||
if options != nil && options.WithCursor {
|
||||
withCursor = true
|
||||
}
|
||||
return &FTHybridCmd{
|
||||
baseCmd: baseCmd{
|
||||
ctx: ctx,
|
||||
args: args,
|
||||
},
|
||||
options: options,
|
||||
withCursor: withCursor,
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) String() string {
|
||||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) SetVal(val FTHybridResult) {
|
||||
cmd.val = val
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) Result() (FTHybridResult, error) {
|
||||
return cmd.val, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) CursorResult() (*FTHybridCursorResult, error) {
|
||||
return cmd.cursorVal, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) Val() FTHybridResult {
|
||||
return cmd.val
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) CursorVal() *FTHybridCursorResult {
|
||||
return cmd.cursorVal
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) RawVal() interface{} {
|
||||
return cmd.rawVal
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) RawResult() (interface{}, error) {
|
||||
return cmd.rawVal, cmd.err
|
||||
}
|
||||
|
||||
func parseFTHybrid(data []interface{}, withCursor bool) (FTHybridResult, *FTHybridCursorResult, error) {
|
||||
// Convert to map
|
||||
resultMap := make(map[string]interface{})
|
||||
for i := 0; i < len(data); i += 2 {
|
||||
if i+1 < len(data) {
|
||||
key, ok := data[i].(string)
|
||||
if !ok {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid key type at index %d", i)
|
||||
}
|
||||
resultMap[key] = data[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
// Handle cursor result
|
||||
if withCursor {
|
||||
searchCursorID, ok1 := resultMap["SEARCH"].(int64)
|
||||
vsimCursorID, ok2 := resultMap["VSIM"].(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid cursor result format")
|
||||
}
|
||||
return FTHybridResult{}, &FTHybridCursorResult{
|
||||
SearchCursorID: int(searchCursorID),
|
||||
VsimCursorID: int(vsimCursorID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse regular result
|
||||
totalResults, ok := resultMap["total_results"].(int64)
|
||||
if !ok {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid total_results format")
|
||||
}
|
||||
|
||||
resultsData, ok := resultMap["results"].([]interface{})
|
||||
if !ok {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid results format")
|
||||
}
|
||||
|
||||
// Parse each result item
|
||||
results := make([]map[string]interface{}, 0, len(resultsData))
|
||||
for _, item := range resultsData {
|
||||
// Try parsing as map[string]interface{} first (RESP3 format)
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
results = append(results, itemMap)
|
||||
continue
|
||||
}
|
||||
|
||||
// Try parsing as map[interface{}]interface{} (alternative RESP3 format)
|
||||
if rawMap, ok := item.(map[interface{}]interface{}); ok {
|
||||
itemMap := make(map[string]interface{})
|
||||
for k, v := range rawMap {
|
||||
if keyStr, ok := k.(string); ok {
|
||||
itemMap[keyStr] = v
|
||||
}
|
||||
}
|
||||
results = append(results, itemMap)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fall back to array format (RESP2 format - key-value pairs)
|
||||
itemData, ok := item.([]interface{})
|
||||
if !ok {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid result item format")
|
||||
}
|
||||
|
||||
itemMap := make(map[string]interface{})
|
||||
for i := 0; i < len(itemData); i += 2 {
|
||||
if i+1 < len(itemData) {
|
||||
key, ok := itemData[i].(string)
|
||||
if !ok {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid item key format")
|
||||
}
|
||||
itemMap[key] = itemData[i+1]
|
||||
}
|
||||
}
|
||||
results = append(results, itemMap)
|
||||
}
|
||||
|
||||
// Parse warnings (optional field)
|
||||
var warnings []string
|
||||
if warningsData, ok := resultMap["warnings"].([]interface{}); ok {
|
||||
warnings = make([]string, 0, len(warningsData))
|
||||
for _, w := range warningsData {
|
||||
if ws, ok := w.(string); ok {
|
||||
warnings = append(warnings, ws)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse execution time (optional field)
|
||||
var executionTime float64
|
||||
if execTimeVal, exists := resultMap["execution_time"]; exists {
|
||||
switch v := execTimeVal.(type) {
|
||||
case string:
|
||||
var err error
|
||||
executionTime, err = strconv.ParseFloat(v, 64)
|
||||
if err != nil {
|
||||
return FTHybridResult{}, nil, fmt.Errorf("invalid execution_time format: %v", err)
|
||||
}
|
||||
case float64:
|
||||
executionTime = v
|
||||
case int64:
|
||||
executionTime = float64(v)
|
||||
}
|
||||
}
|
||||
|
||||
return FTHybridResult{
|
||||
TotalResults: int(totalResults),
|
||||
Results: results,
|
||||
Warnings: warnings,
|
||||
ExecutionTime: executionTime,
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
func (cmd *FTHybridCmd) readReply(rd *proto.Reader) (err error) {
|
||||
data, err := rd.ReadSlice()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, cursorResult, err := parseFTHybrid(data, cmd.withCursor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmd.withCursor {
|
||||
cmd.cursorVal = cursorResult
|
||||
} else {
|
||||
cmd.val = result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FTSearch - Executes a search query on an index.
|
||||
// The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query.
|
||||
// For more information, please refer to the Redis documentation about [FT.SEARCH].
|
||||
@@ -2191,3 +2525,204 @@ func (c cmdable) FTTagVals(ctx context.Context, index string, field string) *Str
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// FTHybrid - Executes a hybrid search combining full-text search and vector similarity
|
||||
// The 'index' parameter specifies the index to search, 'searchExpr' is the search query,
|
||||
// 'vectorField' is the name of the vector field, and 'vectorData' is the vector to search with.
|
||||
// FTHybrid is still experimental, the command behaviour and signature may change
|
||||
func (c cmdable) FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd {
|
||||
options := &FTHybridOptions{
|
||||
CountExpressions: 2,
|
||||
SearchExpressions: []FTHybridSearchExpression{
|
||||
{Query: searchExpr},
|
||||
},
|
||||
VectorExpressions: []FTHybridVectorExpression{
|
||||
{VectorField: vectorField, VectorData: vectorData},
|
||||
},
|
||||
}
|
||||
return c.FTHybridWithArgs(ctx, index, options)
|
||||
}
|
||||
|
||||
// FTHybridWithArgs - Executes a hybrid search with advanced options
|
||||
// FTHybridWithArgs is still experimental, the command behaviour and signature may change
|
||||
func (c cmdable) FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd {
|
||||
args := []interface{}{"FT.HYBRID", index}
|
||||
|
||||
if options != nil {
|
||||
// Add search expressions
|
||||
for _, searchExpr := range options.SearchExpressions {
|
||||
args = append(args, "SEARCH", searchExpr.Query)
|
||||
|
||||
if searchExpr.Scorer != "" {
|
||||
args = append(args, "SCORER", searchExpr.Scorer)
|
||||
if len(searchExpr.ScorerParams) > 0 {
|
||||
args = append(args, searchExpr.ScorerParams...)
|
||||
}
|
||||
}
|
||||
|
||||
if searchExpr.YieldScoreAs != "" {
|
||||
args = append(args, "YIELD_SCORE_AS", searchExpr.YieldScoreAs)
|
||||
}
|
||||
}
|
||||
|
||||
// Add vector expressions
|
||||
for _, vectorExpr := range options.VectorExpressions {
|
||||
args = append(args, "VSIM", "@"+vectorExpr.VectorField)
|
||||
|
||||
// For FT.HYBRID, we need to send just the raw vector bytes, not the Value() format
|
||||
// Value() returns [format, data] but FT.HYBRID expects just the blob
|
||||
vectorValue := vectorExpr.VectorData.Value()
|
||||
if len(vectorValue) >= 2 {
|
||||
// vectorValue is [format, data, ...] - we only want the data part
|
||||
args = append(args, vectorValue[1])
|
||||
} else {
|
||||
// Fallback for unexpected format
|
||||
args = append(args, vectorValue...)
|
||||
}
|
||||
|
||||
if vectorExpr.Method != "" {
|
||||
args = append(args, vectorExpr.Method)
|
||||
if len(vectorExpr.MethodParams) > 0 {
|
||||
// MethodParams should be key-value pairs, count them
|
||||
args = append(args, len(vectorExpr.MethodParams))
|
||||
args = append(args, vectorExpr.MethodParams...)
|
||||
}
|
||||
}
|
||||
|
||||
if vectorExpr.Filter != "" {
|
||||
args = append(args, "FILTER", vectorExpr.Filter)
|
||||
}
|
||||
|
||||
if vectorExpr.YieldScoreAs != "" {
|
||||
args = append(args, "YIELD_SCORE_AS", vectorExpr.YieldScoreAs)
|
||||
}
|
||||
}
|
||||
|
||||
// Add combine/fusion options
|
||||
if options.Combine != nil {
|
||||
// Build combine parameters
|
||||
combineParams := []interface{}{}
|
||||
|
||||
switch options.Combine.Method {
|
||||
case FTHybridCombineRRF:
|
||||
if options.Combine.Window > 0 {
|
||||
combineParams = append(combineParams, "WINDOW", options.Combine.Window)
|
||||
}
|
||||
if options.Combine.Constant > 0 {
|
||||
combineParams = append(combineParams, "CONSTANT", options.Combine.Constant)
|
||||
}
|
||||
case FTHybridCombineLinear:
|
||||
if options.Combine.Alpha > 0 {
|
||||
combineParams = append(combineParams, "ALPHA", options.Combine.Alpha)
|
||||
}
|
||||
if options.Combine.Beta > 0 {
|
||||
combineParams = append(combineParams, "BETA", options.Combine.Beta)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Combine.YieldScoreAs != "" {
|
||||
combineParams = append(combineParams, "YIELD_SCORE_AS", options.Combine.YieldScoreAs)
|
||||
}
|
||||
|
||||
// Add COMBINE with method and parameter count
|
||||
args = append(args, "COMBINE", string(options.Combine.Method))
|
||||
if len(combineParams) > 0 {
|
||||
args = append(args, len(combineParams))
|
||||
args = append(args, combineParams...)
|
||||
}
|
||||
}
|
||||
|
||||
// Add LOAD (projected fields)
|
||||
if len(options.Load) > 0 {
|
||||
args = append(args, "LOAD", len(options.Load))
|
||||
for _, field := range options.Load {
|
||||
args = append(args, field)
|
||||
}
|
||||
}
|
||||
|
||||
// Add GROUPBY
|
||||
if options.GroupBy != nil {
|
||||
args = append(args, "GROUPBY", options.GroupBy.Count)
|
||||
for _, field := range options.GroupBy.Fields {
|
||||
args = append(args, field)
|
||||
}
|
||||
if options.GroupBy.ReduceFunc != "" {
|
||||
args = append(args, "REDUCE", options.GroupBy.ReduceFunc, options.GroupBy.ReduceCount)
|
||||
args = append(args, options.GroupBy.ReduceParams...)
|
||||
}
|
||||
}
|
||||
|
||||
// Add APPLY transformations
|
||||
for _, apply := range options.Apply {
|
||||
args = append(args, "APPLY", apply.Expression, "AS", apply.AsField)
|
||||
}
|
||||
|
||||
// Add SORTBY
|
||||
if len(options.SortBy) > 0 {
|
||||
sortByOptions := []interface{}{}
|
||||
for _, sortBy := range options.SortBy {
|
||||
sortByOptions = append(sortByOptions, sortBy.FieldName)
|
||||
if sortBy.Asc && sortBy.Desc {
|
||||
cmd := newFTHybridCmd(ctx, options, args...)
|
||||
cmd.SetErr(fmt.Errorf("FT.HYBRID: ASC and DESC are mutually exclusive"))
|
||||
return cmd
|
||||
}
|
||||
if sortBy.Asc {
|
||||
sortByOptions = append(sortByOptions, "ASC")
|
||||
}
|
||||
if sortBy.Desc {
|
||||
sortByOptions = append(sortByOptions, "DESC")
|
||||
}
|
||||
}
|
||||
args = append(args, "SORTBY", len(sortByOptions))
|
||||
args = append(args, sortByOptions...)
|
||||
}
|
||||
|
||||
// Add FILTER (post-filter)
|
||||
if options.Filter != "" {
|
||||
args = append(args, "FILTER", options.Filter)
|
||||
}
|
||||
|
||||
// Add LIMIT
|
||||
if options.LimitOffset >= 0 && options.Limit > 0 || options.LimitOffset > 0 && options.Limit == 0 {
|
||||
args = append(args, "LIMIT", options.LimitOffset, options.Limit)
|
||||
}
|
||||
|
||||
// Add PARAMS
|
||||
if len(options.Params) > 0 {
|
||||
args = append(args, "PARAMS", len(options.Params)*2)
|
||||
for key, value := range options.Params {
|
||||
// Parameter keys should already have '$' prefix from the user
|
||||
// Don't add it again if it's already there
|
||||
args = append(args, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Add EXPLAINSCORE
|
||||
if options.ExplainScore {
|
||||
args = append(args, "EXPLAINSCORE")
|
||||
}
|
||||
|
||||
// Add TIMEOUT
|
||||
if options.Timeout > 0 {
|
||||
args = append(args, "TIMEOUT", options.Timeout)
|
||||
}
|
||||
|
||||
// Add WITHCURSOR support
|
||||
if options.WithCursor {
|
||||
args = append(args, "WITHCURSOR")
|
||||
if options.WithCursorOptions != nil {
|
||||
if options.WithCursorOptions.Count > 0 {
|
||||
args = append(args, "COUNT", options.WithCursorOptions.Count)
|
||||
}
|
||||
if options.WithCursorOptions.MaxIdle > 0 {
|
||||
args = append(args, "MAXIDLE", options.WithCursorOptions.MaxIdle)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd := newFTHybridCmd(ctx, options, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
+100
-13
@@ -17,6 +17,8 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -62,6 +64,8 @@ type FailoverOptions struct {
|
||||
Protocol int
|
||||
Username string
|
||||
Password string
|
||||
|
||||
// Push notifications are always enabled for RESP3 connections
|
||||
// CredentialsProvider allows the username and password to be updated
|
||||
// before reconnecting. It should return the current username and password.
|
||||
CredentialsProvider func() (username string, password string)
|
||||
@@ -136,6 +140,15 @@ type FailoverOptions struct {
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
UnstableResp3 bool
|
||||
|
||||
// MaintNotificationsConfig is not supported for FailoverClients at the moment
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// upgrade notifications gracefully and manage connection/pool state transitions
|
||||
// seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications upgrades are disabled.
|
||||
// (however if Mode is nil, it defaults to "auto" - enable if server supports it)
|
||||
//MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *FailoverOptions) clientOptions() *Options {
|
||||
@@ -182,6 +195,10 @@ func (opt *FailoverOptions) clientOptions() *Options {
|
||||
|
||||
IdentitySuffix: opt.IdentitySuffix,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeDisabled,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,6 +243,10 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options {
|
||||
|
||||
IdentitySuffix: opt.IdentitySuffix,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeDisabled,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,6 +296,10 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions {
|
||||
DisableIndentity: opt.DisableIndentity,
|
||||
IdentitySuffix: opt.IdentitySuffix,
|
||||
FailingTimeoutSeconds: opt.FailingTimeoutSeconds,
|
||||
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeDisabled,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,8 +479,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
||||
opt.Dialer = masterReplicaDialer(failover)
|
||||
opt.init()
|
||||
|
||||
var connPool *pool.ConnPool
|
||||
|
||||
rdb := &Client{
|
||||
baseClient: &baseClient{
|
||||
opt: opt,
|
||||
@@ -463,15 +486,29 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
|
||||
}
|
||||
rdb.init()
|
||||
|
||||
connPool = newConnPool(opt, rdb.dialHook)
|
||||
rdb.connPool = connPool
|
||||
// Initialize push notification processor using shared helper
|
||||
// Use void processor by default for RESP2 connections
|
||||
rdb.pushProcessor = initializePushProcessor(opt)
|
||||
|
||||
var err error
|
||||
rdb.connPool, err = newConnPool(opt, rdb.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
|
||||
}
|
||||
rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
|
||||
}
|
||||
|
||||
rdb.onClose = rdb.wrappedOnClose(failover.Close)
|
||||
|
||||
failover.mu.Lock()
|
||||
failover.onFailover = func(ctx context.Context, addr string) {
|
||||
_ = connPool.Filter(func(cn *pool.Conn) bool {
|
||||
return cn.RemoteAddr().String() != addr
|
||||
})
|
||||
if connPool, ok := rdb.connPool.(*pool.ConnPool); ok {
|
||||
_ = connPool.Filter(func(cn *pool.Conn) bool {
|
||||
return cn.RemoteAddr().String() != addr
|
||||
})
|
||||
}
|
||||
}
|
||||
failover.mu.Unlock()
|
||||
|
||||
@@ -529,15 +566,40 @@ func NewSentinelClient(opt *Options) *SentinelClient {
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize push notification processor using shared helper
|
||||
// Use void processor for Sentinel clients
|
||||
c.pushProcessor = NewVoidPushNotificationProcessor()
|
||||
|
||||
c.initHooks(hooks{
|
||||
dial: c.baseClient.dial,
|
||||
process: c.baseClient.process,
|
||||
})
|
||||
c.connPool = newConnPool(opt, c.dialHook)
|
||||
var err error
|
||||
c.connPool, err = newConnPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create connection pool: %w", err))
|
||||
}
|
||||
c.pubSubPool, err = newPubSubPool(opt, c.dialHook)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err))
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// GetPushNotificationHandler returns the handler for a specific push notification name.
|
||||
// Returns nil if no handler is registered for the given name.
|
||||
func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler {
|
||||
return c.pushProcessor.GetHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// RegisterPushNotificationHandler registers a handler for a specific push notification name.
|
||||
// Returns an error if a handler is already registered for this push notification name.
|
||||
// If protected is true, the handler cannot be unregistered.
|
||||
func (c *SentinelClient) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error {
|
||||
return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected)
|
||||
}
|
||||
|
||||
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
|
||||
err := c.processHook(ctx, cmd)
|
||||
cmd.SetErr(err)
|
||||
@@ -547,13 +609,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
|
||||
func (c *SentinelClient) pubSub() *PubSub {
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt,
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
return c.newConn(ctx)
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// will return nil if already initialized
|
||||
err = c.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Track connection in PubSubPool
|
||||
c.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: c.connPool.CloseConn,
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
// Untrack connection from PubSubPool
|
||||
c.pubSubPool.UntrackConn(cn)
|
||||
_ = cn.Close()
|
||||
return nil
|
||||
},
|
||||
pushProcessor: c.pushProcessor,
|
||||
}
|
||||
pubsub.init()
|
||||
|
||||
return pubsub
|
||||
}
|
||||
|
||||
@@ -776,6 +856,11 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// short circuit if no sentinels configured
|
||||
if len(c.sentinelAddrs) == 0 {
|
||||
return "", errors.New("redis: no sentinels configured")
|
||||
}
|
||||
|
||||
var (
|
||||
masterAddr string
|
||||
wg sync.WaitGroup
|
||||
@@ -823,10 +908,12 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
func joinErrors(errs []error) string {
|
||||
if len(errs) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(errs) == 1 {
|
||||
return errs[0].Error()
|
||||
}
|
||||
|
||||
b := []byte(errs[0].Error())
|
||||
for _, err := range errs[1:] {
|
||||
b = append(b, '\n')
|
||||
|
||||
+7
-2
@@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -313,7 +314,9 @@ func (c cmdable) ZPopMax(ctx context.Context, key string, count ...int64) *ZSlic
|
||||
case 1:
|
||||
args = append(args, count[0])
|
||||
default:
|
||||
panic("too many arguments")
|
||||
cmd := NewZSliceCmd(ctx)
|
||||
cmd.SetErr(errors.New("too many arguments"))
|
||||
return cmd
|
||||
}
|
||||
|
||||
cmd := NewZSliceCmd(ctx, args...)
|
||||
@@ -333,7 +336,9 @@ func (c cmdable) ZPopMin(ctx context.Context, key string, count ...int64) *ZSlic
|
||||
case 1:
|
||||
args = append(args, count[0])
|
||||
default:
|
||||
panic("too many arguments")
|
||||
cmd := NewZSliceCmd(ctx)
|
||||
cmd.SetErr(errors.New("too many arguments"))
|
||||
return cmd
|
||||
}
|
||||
|
||||
cmd := NewZSliceCmd(ctx, args...)
|
||||
|
||||
+5
@@ -263,6 +263,7 @@ type XReadGroupArgs struct {
|
||||
Count int64
|
||||
Block time.Duration
|
||||
NoAck bool
|
||||
Claim time.Duration // Claim idle pending entries older than this duration
|
||||
}
|
||||
|
||||
func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSliceCmd {
|
||||
@@ -282,6 +283,10 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic
|
||||
args = append(args, "noack")
|
||||
keyPos++
|
||||
}
|
||||
if a.Claim > 0 {
|
||||
args = append(args, "claim", int64(a.Claim/time.Millisecond))
|
||||
keyPos += 2
|
||||
}
|
||||
args = append(args, "streams")
|
||||
keyPos++
|
||||
for _, s := range a.Streams {
|
||||
|
||||
+446
-1
@@ -2,6 +2,7 @@ package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -9,6 +10,8 @@ type StringCmdable interface {
|
||||
Append(ctx context.Context, key, value string) *IntCmd
|
||||
Decr(ctx context.Context, key string) *IntCmd
|
||||
DecrBy(ctx context.Context, key string, decrement int64) *IntCmd
|
||||
DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd
|
||||
Digest(ctx context.Context, key string) *DigestCmd
|
||||
Get(ctx context.Context, key string) *StringCmd
|
||||
GetRange(ctx context.Context, key string, start, end int64) *StringCmd
|
||||
GetSet(ctx context.Context, key string, value interface{}) *StringCmd
|
||||
@@ -21,9 +24,18 @@ type StringCmdable interface {
|
||||
MGet(ctx context.Context, keys ...string) *SliceCmd
|
||||
MSet(ctx context.Context, values ...interface{}) *StatusCmd
|
||||
MSetNX(ctx context.Context, values ...interface{}) *BoolCmd
|
||||
MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd
|
||||
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
|
||||
SetArgs(ctx context.Context, key string, value interface{}, a SetArgs) *StatusCmd
|
||||
SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd
|
||||
SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd
|
||||
SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd
|
||||
SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd
|
||||
SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd
|
||||
SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd
|
||||
SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd
|
||||
SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd
|
||||
SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd
|
||||
SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
|
||||
SetXX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd
|
||||
SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd
|
||||
@@ -48,6 +60,76 @@ func (c cmdable) DecrBy(ctx context.Context, key string, decrement int64) *IntCm
|
||||
return cmd
|
||||
}
|
||||
|
||||
// DelExArgs provides arguments for the DelExArgs function.
|
||||
type DelExArgs struct {
|
||||
// Mode can be `IFEQ`, `IFNE`, `IFDEQ`, or `IFDNE`.
|
||||
Mode string
|
||||
|
||||
// MatchValue is used with IFEQ/IFNE modes for compare-and-delete operations.
|
||||
// - IFEQ: only delete if current value equals MatchValue
|
||||
// - IFNE: only delete if current value does not equal MatchValue
|
||||
MatchValue interface{}
|
||||
|
||||
// MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-delete.
|
||||
// - IFDEQ: only delete if current value's digest equals MatchDigest
|
||||
// - IFDNE: only delete if current value's digest does not equal MatchDigest
|
||||
//
|
||||
// The digest is a uint64 xxh3 hash value.
|
||||
//
|
||||
// For examples of client-side digest generation, see:
|
||||
// example/digest-optimistic-locking/
|
||||
MatchDigest uint64
|
||||
}
|
||||
|
||||
// DelExArgs Redis `DELEX key [IFEQ|IFNE|IFDEQ|IFDNE] match-value` command.
|
||||
// Compare-and-delete with flexible conditions.
|
||||
//
|
||||
// Returns the number of keys that were removed (0 or 1).
|
||||
//
|
||||
// NOTE DelExArgs is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd {
|
||||
args := []interface{}{"delex", key}
|
||||
|
||||
if a.Mode != "" {
|
||||
args = append(args, a.Mode)
|
||||
|
||||
// Add match value/digest based on mode
|
||||
switch a.Mode {
|
||||
case "ifeq", "IFEQ", "ifne", "IFNE":
|
||||
if a.MatchValue != nil {
|
||||
args = append(args, a.MatchValue)
|
||||
}
|
||||
case "ifdeq", "IFDEQ", "ifdne", "IFDNE":
|
||||
if a.MatchDigest != 0 {
|
||||
args = append(args, fmt.Sprintf("%016x", a.MatchDigest))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cmd := NewIntCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Digest returns the xxh3 hash (uint64) of the specified key's value.
|
||||
//
|
||||
// The digest is a 64-bit xxh3 hash that can be used for optimistic locking
|
||||
// with SetIFDEQ, SetIFDNE, and DelExArgs commands.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/digest/
|
||||
//
|
||||
// NOTE Digest is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) Digest(ctx context.Context, key string) *DigestCmd {
|
||||
cmd := NewDigestCmd(ctx, "digest", key)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Get Redis `GET key` command. It returns redis.Nil error when key does not exist.
|
||||
func (c cmdable) Get(ctx context.Context, key string) *StringCmd {
|
||||
cmd := NewStringCmd(ctx, "get", key)
|
||||
@@ -112,6 +194,35 @@ func (c cmdable) IncrByFloat(ctx context.Context, key string, value float64) *Fl
|
||||
return cmd
|
||||
}
|
||||
|
||||
type SetCondition string
|
||||
|
||||
const (
|
||||
// NX only set the keys and their expiration if none exist
|
||||
NX SetCondition = "NX"
|
||||
// XX only set the keys and their expiration if all already exist
|
||||
XX SetCondition = "XX"
|
||||
)
|
||||
|
||||
type ExpirationMode string
|
||||
|
||||
const (
|
||||
// EX sets expiration in seconds
|
||||
EX ExpirationMode = "EX"
|
||||
// PX sets expiration in milliseconds
|
||||
PX ExpirationMode = "PX"
|
||||
// EXAT sets expiration as Unix timestamp in seconds
|
||||
EXAT ExpirationMode = "EXAT"
|
||||
// PXAT sets expiration as Unix timestamp in milliseconds
|
||||
PXAT ExpirationMode = "PXAT"
|
||||
// KEEPTTL keeps the existing TTL
|
||||
KEEPTTL ExpirationMode = "KEEPTTL"
|
||||
)
|
||||
|
||||
type ExpirationOption struct {
|
||||
Mode ExpirationMode
|
||||
Value int64
|
||||
}
|
||||
|
||||
func (c cmdable) LCS(ctx context.Context, q *LCSQuery) *LCSCmd {
|
||||
cmd := NewLCSCmd(ctx, q)
|
||||
_ = c(ctx, cmd)
|
||||
@@ -157,6 +268,49 @@ func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
type MSetEXArgs struct {
|
||||
Condition SetCondition
|
||||
Expiration *ExpirationOption
|
||||
}
|
||||
|
||||
// MSetEX sets the given keys to their respective values.
|
||||
// This command is an extension of MSETNX that adds expiration and XX options.
|
||||
// Available since Redis 8.4
|
||||
// Important: When this method is used with Cluster clients, all keys
|
||||
// must be in the same hash slot, otherwise CROSSSLOT error will be returned.
|
||||
// For more information, see https://redis.io/commands/msetex
|
||||
func (c cmdable) MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd {
|
||||
expandedArgs := appendArgs([]interface{}{}, values)
|
||||
numkeys := len(expandedArgs) / 2
|
||||
|
||||
cmdArgs := make([]interface{}, 0, 2+len(expandedArgs)+3)
|
||||
cmdArgs = append(cmdArgs, "msetex", numkeys)
|
||||
cmdArgs = append(cmdArgs, expandedArgs...)
|
||||
|
||||
if args.Condition != "" {
|
||||
cmdArgs = append(cmdArgs, string(args.Condition))
|
||||
}
|
||||
|
||||
if args.Expiration != nil {
|
||||
switch args.Expiration.Mode {
|
||||
case EX:
|
||||
cmdArgs = append(cmdArgs, "ex", args.Expiration.Value)
|
||||
case PX:
|
||||
cmdArgs = append(cmdArgs, "px", args.Expiration.Value)
|
||||
case EXAT:
|
||||
cmdArgs = append(cmdArgs, "exat", args.Expiration.Value)
|
||||
case PXAT:
|
||||
cmdArgs = append(cmdArgs, "pxat", args.Expiration.Value)
|
||||
case KEEPTTL:
|
||||
cmdArgs = append(cmdArgs, "keepttl")
|
||||
}
|
||||
}
|
||||
|
||||
cmd := NewIntCmd(ctx, cmdArgs...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Set Redis `SET key value [expiration]` command.
|
||||
// Use expiration for `SETEx`-like behavior.
|
||||
//
|
||||
@@ -185,9 +339,24 @@ func (c cmdable) Set(ctx context.Context, key string, value interface{}, expirat
|
||||
|
||||
// SetArgs provides arguments for the SetArgs function.
|
||||
type SetArgs struct {
|
||||
// Mode can be `NX` or `XX` or empty.
|
||||
// Mode can be `NX`, `XX`, `IFEQ`, `IFNE`, `IFDEQ`, `IFDNE` or empty.
|
||||
Mode string
|
||||
|
||||
// MatchValue is used with IFEQ/IFNE modes for compare-and-set operations.
|
||||
// - IFEQ: only set if current value equals MatchValue
|
||||
// - IFNE: only set if current value does not equal MatchValue
|
||||
MatchValue interface{}
|
||||
|
||||
// MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-set.
|
||||
// - IFDEQ: only set if current value's digest equals MatchDigest
|
||||
// - IFDNE: only set if current value's digest does not equal MatchDigest
|
||||
//
|
||||
// The digest is a uint64 xxh3 hash value.
|
||||
//
|
||||
// For examples of client-side digest generation, see:
|
||||
// example/digest-optimistic-locking/
|
||||
MatchDigest uint64
|
||||
|
||||
// Zero `TTL` or `ExpireAt` means that the key has no expiration time.
|
||||
TTL time.Duration
|
||||
ExpireAt time.Time
|
||||
@@ -223,6 +392,18 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S
|
||||
|
||||
if a.Mode != "" {
|
||||
args = append(args, a.Mode)
|
||||
|
||||
// Add match value/digest for CAS modes
|
||||
switch a.Mode {
|
||||
case "ifeq", "IFEQ", "ifne", "IFNE":
|
||||
if a.MatchValue != nil {
|
||||
args = append(args, a.MatchValue)
|
||||
}
|
||||
case "ifdeq", "IFDEQ", "ifdne", "IFDNE":
|
||||
if a.MatchDigest != 0 {
|
||||
args = append(args, fmt.Sprintf("%016x", a.MatchDigest))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if a.Get {
|
||||
@@ -290,6 +471,270 @@ func (c cmdable) SetXX(ctx context.Context, key string, value interface{}, expir
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFEQ Redis `SET key value [expiration] IFEQ match-value` command.
|
||||
// Compare-and-set: only sets the value if the current value equals matchValue.
|
||||
//
|
||||
// Returns "OK" on success.
|
||||
// Returns nil if the operation was aborted due to condition not matching.
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// NOTE SetIFEQ is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifeq", matchValue)
|
||||
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFEQGet Redis `SET key value [expiration] IFEQ match-value GET` command.
|
||||
// Compare-and-set with GET: only sets the value if the current value equals matchValue,
|
||||
// and returns the previous value.
|
||||
//
|
||||
// Returns the previous value on success.
|
||||
// Returns nil if the operation was aborted due to condition not matching.
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// NOTE SetIFEQGet is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifeq", matchValue, "get")
|
||||
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFNE Redis `SET key value [expiration] IFNE match-value` command.
|
||||
// Compare-and-set: only sets the value if the current value does not equal matchValue.
|
||||
//
|
||||
// Returns "OK" on success.
|
||||
// Returns nil if the operation was aborted due to condition not matching.
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// NOTE SetIFNE is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifne", matchValue)
|
||||
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFNEGet Redis `SET key value [expiration] IFNE match-value GET` command.
|
||||
// Compare-and-set with GET: only sets the value if the current value does not equal matchValue,
|
||||
// and returns the previous value.
|
||||
//
|
||||
// Returns the previous value on success.
|
||||
// Returns nil if the operation was aborted due to condition not matching.
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// NOTE SetIFNEGet is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifne", matchValue, "get")
|
||||
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFDEQ sets the value only if the current value's digest equals matchDigest.
|
||||
//
|
||||
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
|
||||
// The matchDigest parameter is a uint64 xxh3 hash value.
|
||||
//
|
||||
// Returns "OK" on success.
|
||||
// Returns redis.Nil if the digest doesn't match (value was modified).
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/set/
|
||||
//
|
||||
// NOTE SetIFDEQ is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest))
|
||||
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFDEQGet sets the value only if the current value's digest equals matchDigest,
|
||||
// and returns the previous value.
|
||||
//
|
||||
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
|
||||
// The matchDigest parameter is a uint64 xxh3 hash value.
|
||||
//
|
||||
// Returns the previous value on success.
|
||||
// Returns redis.Nil if the digest doesn't match (value was modified).
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/set/
|
||||
//
|
||||
// NOTE SetIFDEQGet is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest), "get")
|
||||
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFDNE sets the value only if the current value's digest does NOT equal matchDigest.
|
||||
//
|
||||
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
|
||||
// The matchDigest parameter is a uint64 xxh3 hash value.
|
||||
//
|
||||
// Returns "OK" on success (digest didn't match, value was set).
|
||||
// Returns redis.Nil if the digest matches (value was not modified).
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/set/
|
||||
//
|
||||
// NOTE SetIFDNE is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest))
|
||||
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// SetIFDNEGet sets the value only if the current value's digest does NOT equal matchDigest,
|
||||
// and returns the previous value.
|
||||
//
|
||||
// This is a compare-and-set operation using xxh3 digest for optimistic locking.
|
||||
// The matchDigest parameter is a uint64 xxh3 hash value.
|
||||
//
|
||||
// Returns the previous value on success (digest didn't match, value was set).
|
||||
// Returns redis.Nil if the digest matches (value was not modified).
|
||||
// Zero expiration means the key has no expiration time.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/set/
|
||||
//
|
||||
// NOTE SetIFDNEGet is still experimental
|
||||
// its signature and behaviour may change
|
||||
func (c cmdable) SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd {
|
||||
args := []interface{}{"set", key, value}
|
||||
|
||||
if expiration > 0 {
|
||||
if usePrecise(expiration) {
|
||||
args = append(args, "px", formatMs(ctx, expiration))
|
||||
} else {
|
||||
args = append(args, "ex", formatSec(ctx, expiration))
|
||||
}
|
||||
} else if expiration == KeepTTL {
|
||||
args = append(args, "keepttl")
|
||||
}
|
||||
|
||||
args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest), "get")
|
||||
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd {
|
||||
cmd := NewIntCmd(ctx, "setrange", key, offset, value)
|
||||
_ = c(ctx, cmd)
|
||||
|
||||
+4
-3
@@ -24,9 +24,10 @@ type Tx struct {
|
||||
func (c *Client) newTx() *Tx {
|
||||
tx := Tx{
|
||||
baseClient: baseClient{
|
||||
opt: c.opt,
|
||||
connPool: pool.NewStickyConnPool(c.connPool),
|
||||
hooksMixin: c.hooksMixin.clone(),
|
||||
opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client
|
||||
connPool: pool.NewStickyConnPool(c.connPool),
|
||||
hooksMixin: c.hooksMixin.clone(),
|
||||
pushProcessor: c.pushProcessor, // Copy push processor from parent client
|
||||
},
|
||||
}
|
||||
tx.init()
|
||||
|
||||
+16
-9
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
)
|
||||
|
||||
// UniversalOptions information is required by UniversalClient to establish
|
||||
@@ -122,6 +123,9 @@ type UniversalOptions struct {
|
||||
|
||||
// IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint).
|
||||
IsClusterMode bool
|
||||
|
||||
// MaintNotificationsConfig provides configuration for maintnotifications upgrades.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
// Cluster returns cluster options created from the universal options.
|
||||
@@ -172,11 +176,12 @@ func (o *UniversalOptions) Cluster() *ClusterOptions {
|
||||
|
||||
TLSConfig: o.TLSConfig,
|
||||
|
||||
DisableIdentity: o.DisableIdentity,
|
||||
DisableIndentity: o.DisableIndentity,
|
||||
IdentitySuffix: o.IdentitySuffix,
|
||||
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
|
||||
UnstableResp3: o.UnstableResp3,
|
||||
DisableIdentity: o.DisableIdentity,
|
||||
DisableIndentity: o.DisableIndentity,
|
||||
IdentitySuffix: o.IdentitySuffix,
|
||||
FailingTimeoutSeconds: o.FailingTimeoutSeconds,
|
||||
UnstableResp3: o.UnstableResp3,
|
||||
MaintNotificationsConfig: o.MaintNotificationsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,6 +242,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions {
|
||||
DisableIndentity: o.DisableIndentity,
|
||||
IdentitySuffix: o.IdentitySuffix,
|
||||
UnstableResp3: o.UnstableResp3,
|
||||
// Note: MaintNotificationsConfig not supported for FailoverOptions
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,10 +290,11 @@ func (o *UniversalOptions) Simple() *Options {
|
||||
|
||||
TLSConfig: o.TLSConfig,
|
||||
|
||||
DisableIdentity: o.DisableIdentity,
|
||||
DisableIndentity: o.DisableIndentity,
|
||||
IdentitySuffix: o.IdentitySuffix,
|
||||
UnstableResp3: o.UnstableResp3,
|
||||
DisableIdentity: o.DisableIdentity,
|
||||
DisableIndentity: o.DisableIndentity,
|
||||
IdentitySuffix: o.IdentitySuffix,
|
||||
UnstableResp3: o.UnstableResp3,
|
||||
MaintNotificationsConfig: o.MaintNotificationsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+11
@@ -26,6 +26,7 @@ type VectorSetCmdable interface {
|
||||
VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd
|
||||
VSimWithArgs(ctx context.Context, key string, val Vector, args *VSimArgs) *StringSliceCmd
|
||||
VSimWithArgsWithScores(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorScoreSliceCmd
|
||||
VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd
|
||||
}
|
||||
|
||||
type Vector interface {
|
||||
@@ -345,3 +346,13 @@ func (c cmdable) VSimWithArgsWithScores(ctx context.Context, key string, val Vec
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// `VRANGE key start end count`
|
||||
// a negative count means to return all the elements in the vector set.
|
||||
// note: the API is experimental and may be subject to change.
|
||||
func (c cmdable) VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd {
|
||||
args := []any{"vrange", key, start, end, count}
|
||||
cmd := NewStringSliceCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
+1
-1
@@ -2,5 +2,5 @@ package redis
|
||||
|
||||
// Version is the current release version.
|
||||
func Version() string {
|
||||
return "9.14.0"
|
||||
return "9.17.2"
|
||||
}
|
||||
|
||||
Vendored
+6
-1
@@ -30,17 +30,22 @@ github.com/gorilla/sessions
|
||||
# github.com/pmezard/go-difflib v1.0.0
|
||||
## explicit
|
||||
github.com/pmezard/go-difflib/difflib
|
||||
# github.com/redis/go-redis/v9 v9.14.0
|
||||
# github.com/redis/go-redis/v9 v9.17.2
|
||||
## explicit; go 1.18
|
||||
github.com/redis/go-redis/v9
|
||||
github.com/redis/go-redis/v9/auth
|
||||
github.com/redis/go-redis/v9/internal
|
||||
github.com/redis/go-redis/v9/internal/auth/streaming
|
||||
github.com/redis/go-redis/v9/internal/hashtag
|
||||
github.com/redis/go-redis/v9/internal/hscan
|
||||
github.com/redis/go-redis/v9/internal/interfaces
|
||||
github.com/redis/go-redis/v9/internal/maintnotifications/logs
|
||||
github.com/redis/go-redis/v9/internal/pool
|
||||
github.com/redis/go-redis/v9/internal/proto
|
||||
github.com/redis/go-redis/v9/internal/rand
|
||||
github.com/redis/go-redis/v9/internal/util
|
||||
github.com/redis/go-redis/v9/maintnotifications
|
||||
github.com/redis/go-redis/v9/push
|
||||
# github.com/stretchr/testify v1.10.0
|
||||
## explicit; go 1.17
|
||||
github.com/stretchr/testify/assert
|
||||
|
||||
Reference in New Issue
Block a user