Files
traefikoidc/internal/cache/backends/redis_pool.go
T
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00

478 lines
12 KiB
Go

package backends
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// ConnectionPool manages a pool of Redis connections.
// Pure-Go implementation compatible with Yaegi.
//
// Idle connections wait in a buffered channel; connections that have been
// handed out by Get are tracked only through the atomic counters below.
type ConnectionPool struct {
// config holds the settings passed to NewConnectionPool.
config *PoolConfig
// connections buffers idle connections; its capacity is config.MaxConnections.
connections chan *RedisConn
// mu is held by Close while it closes and drains the connections channel.
mu sync.Mutex
// closed is set once by Close; Get and Put consult it before handing out or pooling a connection.
closed atomic.Bool
// Metrics
// activeConns counts connections handed out by Get and not yet returned through Put.
activeConns atomic.Int32
// totalConns counts connections created by Get and not yet discarded by Get or Put.
totalConns atomic.Int32
// gets counts calls to Get made while the pool was open.
gets atomic.Int64
// puts counts calls to Put with a non-nil connection.
puts atomic.Int64
// timeouts counts Get calls that gave up while waiting on an exhausted pool
// (context cancelled or ConnectTimeout elapsed).
timeouts atomic.Int64
}
// PoolConfig holds connection pool configuration.
//
// NewConnectionPool stores the pointer it is given and fills in defaults on
// that same struct, so the caller's value is modified in place.
type PoolConfig struct {
// Address is the Redis endpoint dialed over "tcp" (host:port form).
Address string
// Password, when non-empty, is sent with AUTH right after connecting.
Password string
TLSServerName string // SNI server name; defaults to host(Address) when empty
// DB, when non-zero, is selected with SELECT right after connecting.
DB int
// MaxConnections bounds the pool size; values <= 0 are replaced by 10.
MaxConnections int
// ConnectTimeout bounds dialing and the wait for a free connection when the
// pool is exhausted; a zero value is replaced by 5 seconds.
ConnectTimeout time.Duration
// ReadTimeout, when positive, is the per-command read deadline used by RedisConn.Do.
ReadTimeout time.Duration
// WriteTimeout, when positive, is the per-command write deadline used by RedisConn.Do.
WriteTimeout time.Duration
EnableHealthCheck bool // Enable connection health validation (PING before reusing an idle connection)
// NOTE(review): MaxRetries and RetryDelay are not read by the pool code in
// this file; presumably they are consumed by the backend that owns the pool.
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
}
// NewConnectionPool creates a new connection pool for the given configuration.
//
// The config pointer is retained by the pool and defaults are written back
// into it: MaxConnections <= 0 becomes 10 and ConnectTimeout <= 0 becomes
// 5 seconds. No connection is dialed here; connections are created lazily by
// Get.
//
// It returns an error only when config is nil.
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.MaxConnections <= 0 {
		config.MaxConnections = 10
	}
	// A negative timeout would make every dial deadline already expired and
	// every exhausted-pool wait fire immediately, so treat it like "unset".
	if config.ConnectTimeout <= 0 {
		config.ConnectTimeout = 5 * time.Second
	}
	pool := &ConnectionPool{
		config:      config,
		connections: make(chan *RedisConn, config.MaxConnections),
	}
	return pool, nil
}
// Get retrieves a connection from the pool or creates a new one.
//
// Each of up to three attempts proceeds as follows:
//  1. Reuse an idle connection if one is immediately available. When
//     EnableHealthCheck is set the connection is validated first; a stale
//     connection is closed and the attempt is repeated.
//  2. Otherwise, if the pool is below MaxConnections, dial a new connection.
//     A failed dial is retried after a linearly growing pause
//     (100ms, 200ms, ...) unless it was the last attempt.
//  3. Otherwise wait for another caller to Put a connection back, for at most
//     ConnectTimeout.
//
// Errors: ErrBackendClosed when the pool is (or becomes) closed, ctx.Err()
// when ctx is cancelled, ErrPoolExhausted when the wait in step 3 times out,
// the dial error from the final failed attempt, or a generic error when every
// attempt ended with a stale connection.
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
	if p.closed.Load() {
		return nil, ErrBackendClosed
	}
	p.gets.Add(1)

	// #nosec G115 -- MaxConnections is a small config value that fits in int32
	limit := int32(p.config.MaxConnections)

	// reserveSlot atomically claims capacity for one new connection, so that
	// concurrent callers cannot both pass the limit check and overshoot
	// MaxConnections.
	reserveSlot := func() bool {
		for {
			current := p.totalConns.Load()
			if current >= limit {
				return false
			}
			if p.totalConns.CompareAndSwap(current, current+1) {
				return true
			}
		}
	}

	// discardIfStale reports whether conn failed the optional health check;
	// a stale connection is closed and its slot is released.
	discardIfStale := func(conn *RedisConn) bool {
		if !p.config.EnableHealthCheck || p.isConnectionHealthy(conn) {
			return false
		}
		_ = conn.Close()
		p.totalConns.Add(-1)
		return true
	}

	const maxAttempts = 3
	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Fast path: take an idle connection without blocking.
		select {
		case conn, ok := <-p.connections:
			if !ok {
				// Close closed the channel after our initial check; a receive
				// from a closed channel yields nil, which must not be returned.
				return nil, ErrBackendClosed
			}
			if discardIfStale(conn) {
				continue
			}
			p.activeConns.Add(1)
			return conn, nil
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		// No idle connection: create one if there is capacity left.
		if reserveSlot() {
			conn, err := p.createConnection(ctx)
			if err == nil {
				p.activeConns.Add(1)
				return conn, nil
			}
			p.totalConns.Add(-1) // give the reserved slot back
			if attempt == maxAttempts-1 {
				return nil, err
			}
			// Pause before retrying, but stop waiting as soon as ctx is done.
			backoff := time.NewTimer(time.Duration(attempt+1) * 100 * time.Millisecond)
			select {
			case <-backoff.C:
			case <-ctx.Done():
				backoff.Stop()
				return nil, ctx.Err()
			}
			continue
		}

		// Pool exhausted: wait for a connection to be returned.
		wait := time.NewTimer(p.config.ConnectTimeout)
		select {
		case conn, ok := <-p.connections:
			wait.Stop()
			if !ok {
				return nil, ErrBackendClosed
			}
			if discardIfStale(conn) {
				continue
			}
			p.activeConns.Add(1)
			return conn, nil
		case <-ctx.Done():
			wait.Stop()
			p.timeouts.Add(1)
			return nil, ctx.Err()
		case <-wait.C:
			p.timeouts.Add(1)
			return nil, ErrPoolExhausted
		}
	}
	return nil, errors.New("failed to get healthy connection after retries")
}
// Put returns a connection to the pool.
//
// A nil connection is ignored. Otherwise the connection is placed back on the
// idle channel when the pool is open, the connection is not marked closed and
// the channel has room; in every other case the connection is closed and its
// slot in totalConns is released. Put never blocks waiting for room.
func (p *ConnectionPool) Put(conn *RedisConn) {
	if conn == nil {
		return
	}
	p.puts.Add(1)
	p.activeConns.Add(-1)

	pooled := false
	if !conn.closed.Load() {
		// Close closes the idle channel while holding p.mu. Checking the
		// closed flag and sending under the same lock guarantees we never
		// send on a channel that Close has already closed (which would panic).
		p.mu.Lock()
		if !p.closed.Load() {
			select {
			case p.connections <- conn:
				pooled = true
			default:
				// Idle channel is full; fall through and discard.
			}
		}
		p.mu.Unlock()
	}
	if pooled {
		return
	}

	// Pool closed, connection unusable, or no room: discard the connection.
	_ = conn.Close()
	p.totalConns.Add(-1)
}
// Close shuts the pool down and closes every idle connection.
//
// Only the first call does any work; later calls return nil immediately.
// Connections that are currently handed out are not touched here: they are
// closed when they come back through Put, which sees the closed flag.
// Errors from closing individual connections are ignored; Close always
// returns nil.
func (p *ConnectionPool) Close() error {
	if p.closed.Swap(true) {
		return nil
	}
	p.mu.Lock()
	defer p.mu.Unlock()

	// Closing the channel lets the range below terminate once the buffered
	// idle connections have been drained.
	close(p.connections)
	for conn := range p.connections {
		_ = conn.Close()
		// Keep the accounting consistent with every other discard path
		// (Get and Put both release the slot when they close a connection).
		p.totalConns.Add(-1)
	}
	return nil
}
// Stats reports a point-in-time snapshot of the pool counters.
//
// Each counter is read independently, so the values are not guaranteed to be
// mutually consistent while other goroutines are using the pool.
func (p *ConnectionPool) Stats() map[string]interface{} {
	snapshot := make(map[string]interface{}, 6)
	snapshot["active_connections"] = p.activeConns.Load()
	snapshot["total_connections"] = p.totalConns.Load()
	snapshot["max_connections"] = p.config.MaxConnections
	snapshot["gets"] = p.gets.Load()
	snapshot["puts"] = p.puts.Load()
	snapshot["timeouts"] = p.timeouts.Load()
	return snapshot
}
// createConnection dials Redis and prepares the connection for use.
//
// The dial honors both ctx and ConnectTimeout. With EnableTLS the socket is
// wrapped in TLS (minimum version 1.2); the SNI name is TLSServerName or,
// when that is empty and Address parses as host:port, the host part of
// Address. After connecting, AUTH is sent when a password is configured and
// SELECT when DB is non-zero; if either command fails the connection is
// closed and an error is returned.
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
	cfg := p.config
	netDialer := &net.Dialer{Timeout: cfg.ConnectTimeout}

	var (
		raw     net.Conn
		dialErr error
	)
	switch {
	case cfg.EnableTLS:
		sni := cfg.TLSServerName
		if sni == "" {
			host, _, splitErr := net.SplitHostPort(cfg.Address)
			if splitErr == nil {
				sni = host
			}
		}
		secureDialer := &tls.Dialer{
			NetDialer: netDialer,
			Config: &tls.Config{
				ServerName:         sni,
				InsecureSkipVerify: cfg.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
				MinVersion:         tls.VersionTLS12,
			},
		}
		raw, dialErr = secureDialer.DialContext(ctx, "tcp", cfg.Address)
	default:
		raw, dialErr = netDialer.DialContext(ctx, "tcp", cfg.Address)
	}
	if dialErr != nil {
		return nil, fmt.Errorf("failed to connect to Redis: %w", dialErr)
	}

	rc := &RedisConn{
		conn:         raw,
		readTimeout:  cfg.ReadTimeout,
		writeTimeout: cfg.WriteTimeout,
	}

	// Credentials first, then database selection.
	if cfg.Password != "" {
		_, authErr := rc.Do("AUTH", cfg.Password)
		if authErr != nil {
			_ = rc.Close()
			return nil, fmt.Errorf("authentication failed: %w", authErr)
		}
	}
	if cfg.DB != 0 {
		_, selectErr := rc.Do("SELECT", fmt.Sprintf("%d", cfg.DB))
		if selectErr != nil {
			_ = rc.Close()
			return nil, fmt.Errorf("failed to select database: %w", selectErr)
		}
	}
	return rc, nil
}
// RedisConn represents a single Redis connection.
//
// It pairs one network connection with the per-command timeouts that were
// copied from the pool configuration when the connection was created.
type RedisConn struct {
// conn is the underlying network connection (plain TCP or TLS, depending on how it was dialed).
conn net.Conn
// readTimeout, when positive, is applied as a read deadline before a reply is read.
readTimeout time.Duration
// writeTimeout, when positive, is applied as a write deadline before a command is written.
writeTimeout time.Duration
// closed is set by Close, and by Do or Pipeline.Execute after an I/O failure;
// once set, Do and Pipeline.Execute refuse to run.
closed atomic.Bool
// mu serializes use of conn: Do, Pipeline.Execute and Close hold it while touching the socket.
mu sync.Mutex
}
// Do executes a Redis command and returns the response.
//
// The command and its arguments are written as one request and exactly one
// reply is read back. Calls on the same connection are serialized by c.mu.
//
// Limits: at most (1<<20)-1 arguments and at most 64 MiB for the command and
// all arguments combined; violations are rejected before anything is written.
//
// Errors: ErrBackendClosed when the connection is already marked closed, a
// validation error for oversized input, or the write/read error. Any I/O
// error other than ErrNilResponse leaves the stream in an unknown state, so
// the connection is marked closed AND the underlying socket is closed at that
// point (a later Close call is then a no-op).
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
	if c.closed.Load() {
		return nil, ErrBackendClosed
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	// Validate argument count to prevent integer overflow in slice operations.
	// (1<<20)-1 = 1,048,575 is far more than any reasonable Redis command.
	const maxSafeArgs = (1 << 20) - 1
	if len(args) > maxSafeArgs {
		return nil, errors.New("too many arguments: exceeds maximum safe count")
	}

	// Validate total payload size to prevent memory exhaustion.
	const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
	totalBytes := len(command)
	if totalBytes > maxTotalArgBytes {
		return nil, errors.New("total argument size exceeds maximum allowed")
	}
	for _, s := range args {
		// Compare against the remaining budget instead of adding first, so the
		// running total can never overflow or exceed the limit.
		if len(s) > maxTotalArgBytes-totalBytes {
			return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
		}
		totalBytes += len(s)
	}

	// Build command slice: prepend command to args.
	cmdArgs := append([]string{command}, args...)

	// poison marks the connection unusable and releases its socket. Close()
	// returns early once the closed flag is set, so the socket has to be
	// closed here or it would leak. Swap guarantees it is closed only once
	// even if Close() runs concurrently.
	poison := func() {
		if !c.closed.Swap(true) && c.conn != nil {
			_ = c.conn.Close()
		}
	}

	// Write the command.
	if c.writeTimeout > 0 {
		_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
	}
	writer := NewRESPWriter(c.conn)
	err := writer.WriteCommand(cmdArgs...)
	writer.Release() // done with the writer as soon as the command is written
	if err != nil {
		poison()
		return nil, err
	}

	// Read the reply.
	if c.readTimeout > 0 {
		_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
	}
	reader := NewRESPReader(c.conn)
	resp, err := reader.ReadResponse()
	reader.Release() // done with the reader as soon as the reply is parsed
	if err != nil {
		// ErrNilResponse is a well-formed reply (nil value), not a broken stream.
		if !errors.Is(err, ErrNilResponse) {
			poison()
		}
		return nil, err
	}
	return resp, nil
}
// Close marks the connection closed and closes the underlying socket.
//
// It is idempotent: if the closed flag was already set, nothing further is
// done and nil is returned. Otherwise the result of closing the socket is
// returned (nil when there is no socket).
func (c *RedisConn) Close() error {
	alreadyClosed := c.closed.Swap(true)
	if alreadyClosed {
		return nil
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return nil
	}
	return c.conn.Close()
}
// isConnectionHealthy reports whether conn still answers a PING.
//
// A nil or already-closed connection is unhealthy. Before pinging, a
// one-second read deadline is set on the socket and cleared again on return;
// note that Do re-arms the read deadline itself whenever the connection's
// readTimeout is positive, so the one-second bound only applies when no read
// timeout is configured.
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
	switch {
	case conn == nil:
		return false
	case conn.closed.Load():
		return false
	}

	if conn.conn != nil {
		pingDeadline := time.Now().Add(time.Second)
		_ = conn.conn.SetReadDeadline(pingDeadline)
		// Clear the deadline so it does not affect the next command.
		defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }()
	}

	if _, pingErr := conn.Do("PING"); pingErr != nil {
		return false
	}
	return true
}
// Pipeline represents a Redis pipeline for batch operations.
// It queues multiple commands; Execute then writes all of them to the
// connection before reading any reply.
type Pipeline struct {
// conn is the connection the queued commands will be sent on.
conn *RedisConn
// commands holds the queued commands in the order they were added.
commands []pipelineCommand
// mu guards commands against concurrent Queue, Execute, Clear and Len calls.
mu sync.Mutex
}
// pipelineCommand represents a single command in the pipeline.
type pipelineCommand struct {
// command is the Redis command name, written as the first element of the request.
command string
// args are the command's arguments, written after the name in order.
args []string
}
// NewPipeline returns an empty pipeline bound to this connection.
// The command queue starts with room for 16 entries to avoid early regrowth.
func (c *RedisConn) NewPipeline() *Pipeline {
	const initialQueueCapacity = 16

	pl := new(Pipeline)
	pl.conn = c
	pl.commands = make([]pipelineCommand, 0, initialQueueCapacity)
	return pl
}
// Queue appends a command to the pipeline; nothing is sent until Execute.
//
// The arguments are copied, so a caller that passes a slice with `args...`
// may reuse or modify that slice afterwards without altering the queued
// command.
func (p *Pipeline) Queue(command string, args ...string) {
	// Detach from the caller's backing array: the command is executed later,
	// and aliasing would let intermediate writes leak into it.
	owned := append([]string(nil), args...)

	p.mu.Lock()
	defer p.mu.Unlock()
	p.commands = append(p.commands, pipelineCommand{
		command: command,
		args:    owned,
	})
}
// Execute sends all queued commands and returns all responses.
//
// Every command is written before any reply is read, and the returned slice
// holds the replies in the order the commands were queued. A nil reply
// (ErrNilResponse) is stored as nil rather than treated as a failure. The
// queue itself is left untouched; call Clear to reuse the pipeline.
//
// When the connection has positive timeouts, the write and read deadlines are
// the per-command timeout multiplied by the number of commands, capped at 30
// seconds.
//
// Returns (nil, nil) when nothing is queued and ErrBackendClosed when the
// connection is already marked closed. On a write or read failure the
// connection is marked closed and its socket is closed; for a read failure
// the replies received so far are returned alongside the error.
func (p *Pipeline) Execute() ([]interface{}, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if len(p.commands) == 0 {
		return nil, nil
	}
	if p.conn.closed.Load() {
		return nil, ErrBackendClosed
	}
	p.conn.mu.Lock()
	defer p.conn.mu.Unlock()

	// poison marks the connection unusable and releases its socket.
	// RedisConn.Close returns early once the closed flag is set, so the
	// socket has to be closed here or it would leak. Swap guarantees it is
	// closed only once even if Close runs concurrently.
	poison := func() {
		if !p.conn.closed.Swap(true) && p.conn.conn != nil {
			_ = p.conn.conn.Close()
		}
	}

	// batchTimeout scales a per-command timeout to the whole batch, capped
	// at 30 seconds.
	batchTimeout := func(perCommand time.Duration) time.Duration {
		const maxBatchTimeout = 30 * time.Second
		total := perCommand * time.Duration(len(p.commands))
		if total > maxBatchTimeout {
			return maxBatchTimeout
		}
		return total
	}

	// Write all commands (pipelining - send all before reading any responses).
	if p.conn.writeTimeout > 0 {
		_ = p.conn.conn.SetWriteDeadline(time.Now().Add(batchTimeout(p.conn.writeTimeout)))
	}
	writer := NewRESPWriter(p.conn.conn)
	for _, cmd := range p.commands {
		cmdArgs := append([]string{cmd.command}, cmd.args...)
		if err := writer.WriteCommand(cmdArgs...); err != nil {
			writer.Release()
			poison()
			return nil, fmt.Errorf("pipeline write error: %w", err)
		}
	}
	writer.Release()

	// Read one reply per command.
	if p.conn.readTimeout > 0 {
		_ = p.conn.conn.SetReadDeadline(time.Now().Add(batchTimeout(p.conn.readTimeout)))
	}
	responses := make([]interface{}, len(p.commands))
	reader := NewRESPReader(p.conn.conn)
	defer reader.Release()
	for i := range p.commands {
		resp, err := reader.ReadResponse()
		if err != nil {
			// A nil reply is a valid answer; record it and keep reading.
			if errors.Is(err, ErrNilResponse) {
				responses[i] = nil
				continue
			}
			// Anything else leaves unread replies on the wire, so the
			// connection cannot be reused.
			poison()
			return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
		}
		responses[i] = resp
	}
	return responses, nil
}
// Clear empties the command queue so the pipeline can be reused.
// The queue's backing storage is kept, only its length is reset to zero.
func (p *Pipeline) Clear() {
	p.mu.Lock()
	p.commands = p.commands[:0]
	p.mu.Unlock()
}
// Len reports how many commands are currently queued.
func (p *Pipeline) Len() int {
	p.mu.Lock()
	queued := len(p.commands)
	p.mu.Unlock()
	return queued
}