mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
68c150eba4
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the config layer but silently dropped before reaching the connection pool, so the plugin always dialed Redis in plaintext. This blocked TLS-only Redis deployments such as AWS ElastiCache with in-transit encryption. - Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and PoolConfig and forward them through universal_cache_singleton -> backends.Config -> PoolConfig. - In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2 minimum) with SNI defaulting to the host part of the configured Address when TLSServerName is empty, so ElastiCache cluster endpoints validate out of the box. Plain dial path now also propagates ctx. - Add regression tests covering successful TLS negotiation with skip- verify, rejection of self-signed certs without skip-verify, rejection of plain TCP servers when EnableTLS=true, and unaffected plaintext behavior. - Document maxRefreshTokenAgeSeconds (added in1b6c861) and the implicit SSE / WebSocket auth bypass (added in684a990) in README.md, docs/CONFIGURATION.md and docs/index.html. - Add the missing redis.tlsSkipVerify row to docs/index.html and clarify the redis.enableTLS description. patch-release
478 lines
12 KiB
Go
478 lines
12 KiB
Go
package backends
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// ConnectionPool manages a pool of Redis connections
|
|
// Pure-Go implementation compatible with Yaegi
|
|
type ConnectionPool struct {
|
|
config *PoolConfig
|
|
|
|
connections chan *RedisConn
|
|
mu sync.Mutex
|
|
closed atomic.Bool
|
|
|
|
// Metrics
|
|
activeConns atomic.Int32
|
|
totalConns atomic.Int32
|
|
gets atomic.Int64
|
|
puts atomic.Int64
|
|
timeouts atomic.Int64
|
|
}
|
|
|
|
// PoolConfig holds connection pool configuration
|
|
type PoolConfig struct {
|
|
Address string
|
|
Password string
|
|
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
|
DB int
|
|
MaxConnections int
|
|
ConnectTimeout time.Duration
|
|
ReadTimeout time.Duration
|
|
WriteTimeout time.Duration
|
|
EnableHealthCheck bool // Enable connection health validation
|
|
MaxRetries int // Max retries for failed operations
|
|
RetryDelay time.Duration // Initial delay between retries
|
|
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
|
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
|
}
|
|
|
|
// NewConnectionPool creates a new connection pool
|
|
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
|
|
if config == nil {
|
|
return nil, errors.New("config is required")
|
|
}
|
|
|
|
if config.MaxConnections <= 0 {
|
|
config.MaxConnections = 10
|
|
}
|
|
|
|
if config.ConnectTimeout == 0 {
|
|
config.ConnectTimeout = 5 * time.Second
|
|
}
|
|
|
|
pool := &ConnectionPool{
|
|
config: config,
|
|
connections: make(chan *RedisConn, config.MaxConnections),
|
|
}
|
|
|
|
return pool, nil
|
|
}
|
|
|
|
// Get retrieves a connection from the pool or creates a new one
|
|
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
|
if p.closed.Load() {
|
|
return nil, ErrBackendClosed
|
|
}
|
|
|
|
p.gets.Add(1)
|
|
|
|
// Try to get a connection with validation
|
|
maxAttempts := 3
|
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
|
var conn *RedisConn
|
|
var err error
|
|
|
|
select {
|
|
case conn = <-p.connections:
|
|
// Reuse existing connection - validate if health check enabled
|
|
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
|
// Connection is stale, close it and try again
|
|
_ = conn.Close()
|
|
p.totalConns.Add(-1)
|
|
continue
|
|
}
|
|
p.activeConns.Add(1)
|
|
return conn, nil
|
|
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
|
|
default:
|
|
// No available connection, create new one if under limit
|
|
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
|
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
|
conn, err = p.createConnection(ctx)
|
|
if err != nil {
|
|
// If this is the last attempt, return error
|
|
if attempt == maxAttempts-1 {
|
|
return nil, err
|
|
}
|
|
// Wait before retry with exponential backoff
|
|
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
|
continue
|
|
}
|
|
p.activeConns.Add(1)
|
|
p.totalConns.Add(1)
|
|
return conn, nil
|
|
}
|
|
|
|
// Pool exhausted, wait for a connection with timeout
|
|
select {
|
|
case conn = <-p.connections:
|
|
// Validate connection if health check enabled
|
|
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
|
_ = conn.Close()
|
|
p.totalConns.Add(-1)
|
|
continue
|
|
}
|
|
p.activeConns.Add(1)
|
|
return conn, nil
|
|
case <-ctx.Done():
|
|
p.timeouts.Add(1)
|
|
return nil, ctx.Err()
|
|
case <-time.After(p.config.ConnectTimeout):
|
|
p.timeouts.Add(1)
|
|
return nil, ErrPoolExhausted
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, errors.New("failed to get healthy connection after retries")
|
|
}
|
|
|
|
// Put returns a connection to the pool
|
|
func (p *ConnectionPool) Put(conn *RedisConn) {
|
|
if conn == nil {
|
|
return
|
|
}
|
|
|
|
p.puts.Add(1)
|
|
p.activeConns.Add(-1)
|
|
|
|
if p.closed.Load() || conn.closed.Load() {
|
|
_ = conn.Close()
|
|
p.totalConns.Add(-1)
|
|
return
|
|
}
|
|
|
|
// Return to pool (non-blocking)
|
|
select {
|
|
case p.connections <- conn:
|
|
// Successfully returned to pool
|
|
default:
|
|
// Pool full, close connection
|
|
_ = conn.Close()
|
|
p.totalConns.Add(-1)
|
|
}
|
|
}
|
|
|
|
// Close closes all connections in the pool
|
|
func (p *ConnectionPool) Close() error {
|
|
if p.closed.Swap(true) {
|
|
return nil
|
|
}
|
|
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
close(p.connections)
|
|
|
|
// Close all pooled connections
|
|
for conn := range p.connections {
|
|
_ = conn.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stats returns pool statistics
|
|
func (p *ConnectionPool) Stats() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"active_connections": p.activeConns.Load(),
|
|
"total_connections": p.totalConns.Load(),
|
|
"max_connections": p.config.MaxConnections,
|
|
"gets": p.gets.Load(),
|
|
"puts": p.puts.Load(),
|
|
"timeouts": p.timeouts.Load(),
|
|
}
|
|
}
|
|
|
|
// createConnection creates a new Redis connection
|
|
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
|
// Connect with timeout
|
|
dialer := &net.Dialer{
|
|
Timeout: p.config.ConnectTimeout,
|
|
}
|
|
|
|
var conn net.Conn
|
|
var err error
|
|
if p.config.EnableTLS {
|
|
serverName := p.config.TLSServerName
|
|
if serverName == "" {
|
|
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
|
serverName = host
|
|
}
|
|
}
|
|
tlsCfg := &tls.Config{
|
|
ServerName: serverName,
|
|
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
|
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
|
} else {
|
|
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
|
}
|
|
|
|
redisConn := &RedisConn{
|
|
conn: conn,
|
|
readTimeout: p.config.ReadTimeout,
|
|
writeTimeout: p.config.WriteTimeout,
|
|
}
|
|
|
|
// Authenticate if password is provided
|
|
if p.config.Password != "" {
|
|
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
|
_ = redisConn.Close()
|
|
return nil, fmt.Errorf("authentication failed: %w", err)
|
|
}
|
|
}
|
|
|
|
// Select database
|
|
if p.config.DB != 0 {
|
|
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
|
_ = redisConn.Close()
|
|
return nil, fmt.Errorf("failed to select database: %w", err)
|
|
}
|
|
}
|
|
|
|
return redisConn, nil
|
|
}
|
|
|
|
// RedisConn represents a single Redis connection
|
|
type RedisConn struct {
|
|
conn net.Conn
|
|
readTimeout time.Duration
|
|
writeTimeout time.Duration
|
|
closed atomic.Bool
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// Do executes a Redis command and returns the response
|
|
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
|
if c.closed.Load() {
|
|
return nil, ErrBackendClosed
|
|
}
|
|
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Validate argument count to prevent integer overflow in slice operations
|
|
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
|
const maxSafeArgs = (1 << 20) - 1
|
|
if len(args) > maxSafeArgs {
|
|
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
|
}
|
|
|
|
// Build command arguments
|
|
// Validate total argument size to prevent memory exhaustion
|
|
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
|
totalBytes := len(command)
|
|
for _, s := range args {
|
|
// Protect against possible overflow
|
|
if len(s) > maxTotalArgBytes-totalBytes {
|
|
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
|
|
}
|
|
totalBytes += len(s)
|
|
if totalBytes > maxTotalArgBytes {
|
|
return nil, errors.New("total argument size exceeds maximum allowed")
|
|
}
|
|
}
|
|
// Build command slice: prepend command to args
|
|
// Using append avoids arithmetic on potentially large len(args)
|
|
cmdArgs := append([]string{command}, args...)
|
|
|
|
// Set write timeout
|
|
if c.writeTimeout > 0 {
|
|
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
|
}
|
|
|
|
// Write command (using pooled writer for memory efficiency)
|
|
writer := NewRESPWriter(c.conn)
|
|
err := writer.WriteCommand(cmdArgs...)
|
|
writer.Release() // Return to pool immediately after use
|
|
if err != nil {
|
|
c.closed.Store(true)
|
|
return nil, err
|
|
}
|
|
|
|
// Set read timeout
|
|
if c.readTimeout > 0 {
|
|
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
|
}
|
|
|
|
// Read response (using pooled reader for memory efficiency)
|
|
reader := NewRESPReader(c.conn)
|
|
resp, err := reader.ReadResponse()
|
|
reader.Release() // Return to pool immediately after use
|
|
if err != nil {
|
|
if !errors.Is(err, ErrNilResponse) {
|
|
c.closed.Store(true)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// Close closes the connection
|
|
func (c *RedisConn) Close() error {
|
|
if c.closed.Swap(true) {
|
|
return nil
|
|
}
|
|
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.conn != nil {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isConnectionHealthy validates a connection is still working
|
|
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
|
if conn == nil || conn.closed.Load() {
|
|
return false
|
|
}
|
|
|
|
// Set a read deadline for the ping
|
|
if conn.conn != nil {
|
|
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
|
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
|
}
|
|
|
|
_, err := conn.Do("PING")
|
|
return err == nil
|
|
}
|
|
|
|
// Pipeline represents a Redis pipeline for batch operations
|
|
// It queues multiple commands and executes them in a single round-trip
|
|
type Pipeline struct {
|
|
conn *RedisConn
|
|
commands []pipelineCommand
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// pipelineCommand represents a single command in the pipeline
|
|
type pipelineCommand struct {
|
|
command string
|
|
args []string
|
|
}
|
|
|
|
// NewPipeline creates a new pipeline for the connection
|
|
func (c *RedisConn) NewPipeline() *Pipeline {
|
|
return &Pipeline{
|
|
conn: c,
|
|
commands: make([]pipelineCommand, 0, 16), // Pre-allocate for typical batch size
|
|
}
|
|
}
|
|
|
|
// Queue adds a command to the pipeline
|
|
func (p *Pipeline) Queue(command string, args ...string) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.commands = append(p.commands, pipelineCommand{
|
|
command: command,
|
|
args: args,
|
|
})
|
|
}
|
|
|
|
// Execute sends all queued commands and returns all responses
|
|
// Returns a slice of responses in the same order as commands were queued
|
|
func (p *Pipeline) Execute() ([]interface{}, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if len(p.commands) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
if p.conn.closed.Load() {
|
|
return nil, ErrBackendClosed
|
|
}
|
|
|
|
p.conn.mu.Lock()
|
|
defer p.conn.mu.Unlock()
|
|
|
|
// Set write timeout for all commands
|
|
if p.conn.writeTimeout > 0 {
|
|
// Use longer timeout for batch operations
|
|
timeout := p.conn.writeTimeout * time.Duration(len(p.commands))
|
|
if timeout > 30*time.Second {
|
|
timeout = 30 * time.Second // Cap at 30 seconds
|
|
}
|
|
_ = p.conn.conn.SetWriteDeadline(time.Now().Add(timeout))
|
|
}
|
|
|
|
// Write all commands (pipelining - send all before reading any responses)
|
|
writer := NewRESPWriter(p.conn.conn)
|
|
for _, cmd := range p.commands {
|
|
cmdArgs := append([]string{cmd.command}, cmd.args...)
|
|
if err := writer.WriteCommand(cmdArgs...); err != nil {
|
|
writer.Release()
|
|
p.conn.closed.Store(true)
|
|
return nil, fmt.Errorf("pipeline write error: %w", err)
|
|
}
|
|
}
|
|
writer.Release()
|
|
|
|
// Set read timeout for all responses
|
|
if p.conn.readTimeout > 0 {
|
|
timeout := p.conn.readTimeout * time.Duration(len(p.commands))
|
|
if timeout > 30*time.Second {
|
|
timeout = 30 * time.Second
|
|
}
|
|
_ = p.conn.conn.SetReadDeadline(time.Now().Add(timeout))
|
|
}
|
|
|
|
// Read all responses
|
|
responses := make([]interface{}, len(p.commands))
|
|
reader := NewRESPReader(p.conn.conn)
|
|
defer reader.Release()
|
|
|
|
for i := range p.commands {
|
|
resp, err := reader.ReadResponse()
|
|
if err != nil {
|
|
// For nil responses, store nil instead of erroring
|
|
if errors.Is(err, ErrNilResponse) {
|
|
responses[i] = nil
|
|
continue
|
|
}
|
|
p.conn.closed.Store(true)
|
|
return responses[:i], fmt.Errorf("pipeline read error at command %d: %w", i, err)
|
|
}
|
|
responses[i] = resp
|
|
}
|
|
|
|
return responses, nil
|
|
}
|
|
|
|
// Clear resets the pipeline for reuse
|
|
func (p *Pipeline) Clear() {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.commands = p.commands[:0]
|
|
}
|
|
|
|
// Len returns the number of queued commands
|
|
func (p *Pipeline) Len() int {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
return len(p.commands)
|
|
}
|