diff --git a/README.md b/README.md index f3b5cf8..92aafdf 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md). | `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. | | `sessionMaxAge` | `86400` | Session lifetime in seconds. | | `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. | +| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. | | `rateLimit` | `100` | Requests/sec. Min `10`. | | `logLevel` | `info` | `debug`, `info`, `error`. | | `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). | @@ -165,6 +166,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`, otherwise a session minted by one instance can grant access through another. See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87). +### SSE and WebSocket endpoints + +Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a +WebSocket upgrade. The middleware handles this automatically: + +- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`) + requests skip the OIDC redirect. +- They are **not** unauthenticated — a valid encrypted session cookie is + required, otherwise the request is rejected. The session must already exist + (i.e. the user logged in via a normal HTTP page first). +- `X-Forwarded-User` is forwarded from the session. +- Validation is cookie-only (no JWK fetch), so streaming keeps working during + brief IdP outages. + +No configuration needed — this is implicit behavior. 
+ ### HTTP 431 from backends Either the ID token or the chunked OIDC cookies overflow your backend's header diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 4298a24..fa516aa 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -70,6 +70,33 @@ overwrite it). Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local dev). Otherwise leave it at default. +### Streaming Endpoints (SSE and WebSocket) + +The middleware automatically bypasses the OIDC redirect for two request kinds +that browsers cannot follow a 302 on: + +| Bypass | Triggered by | +|--------|--------------| +| Server-Sent Events (SSE) | `Accept: text/event-stream` | +| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) | + +These requests do **not** require any explicit configuration — they are +handled implicitly. However, the bypass is **not** unauthenticated: + +- A valid, encrypted session cookie is required. Requests without one are + rejected (the connection cannot proceed to the backend). +- The session cookie is sealed with `sessionEncryptionKey`, so the + `authenticated` flag cannot be forged. +- Validation is cookie-only — no JWK fetch / signature verification — so + streaming endpoints keep working when the OIDC provider is briefly + unavailable. +- The user identifier from the session is forwarded as `X-Forwarded-User` + (and `X-Auth-Request-User` unless `minimalHeaders: true`). + +For browser clients, the user must complete the normal OIDC flow on a +regular HTTP page first; the resulting session cookie is then reused on the +SSE / WebSocket connection. 
+ --- ## Security Options @@ -113,6 +140,7 @@ strictAudienceValidation: true |-----------|------|---------|-------------| | `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds | | `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh | +| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). | | `cookieDomain` | string | auto-detected | Domain for session cookies | | `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names | diff --git a/docs/index.html b/docs/index.html index 33a9bf1..92267d5 100644 --- a/docs/index.html +++ b/docs/index.html @@ -718,6 +718,11 @@ spec: 86400 Maximum session age in seconds (24 hours default) + + maxRefreshTokenAgeSeconds + 21600 + Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set 0 to disable. + cookiePrefix _oidc_raczylo_ @@ -858,7 +863,12 @@ spec: redis.enableTLS false - Enable TLS for Redis connections + Enable TLS for Redis connections (e.g. 
AWS ElastiCache in-transit encryption) + + + redis.tlsSkipVerify + false + Skip TLS server certificate verification (testing only; not recommended in production) diff --git a/internal/cache/backends/config.go b/internal/cache/backends/config.go index 6f60188..f33daf4 100644 --- a/internal/cache/backends/config.go +++ b/internal/cache/backends/config.go @@ -24,6 +24,7 @@ type Config struct { Type BackendType RedisAddr string RedisPassword string + TLSServerName string PoolSize int RedisDB int CleanupInterval time.Duration @@ -34,6 +35,8 @@ type Config struct { EnableCircuitBreaker bool EnableHealthCheck bool EnableMetrics bool + EnableTLS bool + TLSSkipVerify bool } // DefaultConfig returns a default configuration for in-memory caching diff --git a/internal/cache/backends/redis.go b/internal/cache/backends/redis.go index d8456f1..a27f813 100644 --- a/internal/cache/backends/redis.go +++ b/internal/cache/backends/redis.go @@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) { poolConfig := &PoolConfig{ Address: config.RedisAddr, Password: config.RedisPassword, + TLSServerName: config.TLSServerName, DB: config.RedisDB, MaxConnections: config.PoolSize, ConnectTimeout: 2 * time.Second, @@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) { EnableHealthCheck: true, MaxRetries: 3, RetryDelay: 100 * time.Millisecond, + EnableTLS: config.EnableTLS, + TLSSkipVerify: config.TLSSkipVerify, } pool, err := NewConnectionPool(poolConfig) diff --git a/internal/cache/backends/redis_pool.go b/internal/cache/backends/redis_pool.go index 16b79d0..994b1ef 100644 --- a/internal/cache/backends/redis_pool.go +++ b/internal/cache/backends/redis_pool.go @@ -2,6 +2,7 @@ package backends import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -31,6 +32,7 @@ type ConnectionPool struct { type PoolConfig struct { Address string Password string + TLSServerName string // SNI server name; defaults to host(Address) when empty DB int MaxConnections 
int ConnectTimeout time.Duration @@ -39,6 +41,8 @@ type PoolConfig struct { EnableHealthCheck bool // Enable connection health validation MaxRetries int // Max retries for failed operations RetryDelay time.Duration // Initial delay between retries + EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption) + TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended) } // NewConnectionPool creates a new connection pool @@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) { // No available connection, create new one if under limit // #nosec G115 -- MaxConnections is a small config value that fits in int32 if p.totalConns.Load() < int32(p.config.MaxConnections) { - conn, err = p.createConnection() + conn, err = p.createConnection(ctx) if err != nil { // If this is the last attempt, return error if attempt == maxAttempts-1 { @@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} { } // createConnection creates a new Redis connection -func (p *ConnectionPool) createConnection() (*RedisConn, error) { +func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) { // Connect with timeout dialer := &net.Dialer{ Timeout: p.config.ConnectTimeout, } - conn, err := dialer.Dial("tcp", p.config.Address) + var conn net.Conn + var err error + if p.config.EnableTLS { + serverName := p.config.TLSServerName + if serverName == "" { + if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil { + serverName = host + } + } + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config + MinVersion: tls.VersionTLS12, + } + tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg} + conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address) + } else { + conn, err = dialer.DialContext(ctx, "tcp", p.config.Address) + } if err != 
nil {
 		return nil, fmt.Errorf("failed to connect to Redis: %w", err)
 	}
diff --git a/internal/cache/backends/redis_pool_tls_test.go b/internal/cache/backends/redis_pool_tls_test.go
new file mode 100644
index 0000000..d3d3aa6
--- /dev/null
+++ b/internal/cache/backends/redis_pool_tls_test.go
@@ -0,0 +1,230 @@
+package backends
+
+import (
+	"bufio"
+	"context"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"math/big"
+	"net"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// drainRESPRequest reads one RESP request from r: either an inline command (a single
+// line) or a "*N" array of N bulk strings. It returns false on any read or parse error.
+func drainRESPRequest(r *bufio.Reader) bool {
+	header, err := r.ReadString('\n')
+	if err != nil {
+		return false
+	}
+	if !strings.HasPrefix(header, "*") {
+		return true // inline command (single line) — already consumed
+	}
+	n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
+	if err != nil || n <= 0 { // malformed count, or an empty/negative array
+		return false
+	}
+	for i := 0; i < n; i++ {
+		// Each bulk is two lines: "$len\r\n" then "<data>\r\n" (read line-wise, so data must not contain '\n').
+		if _, err := r.ReadString('\n'); err != nil {
+			return false
+		}
+		if _, err := r.ReadString('\n'); err != nil {
+			return false
+		}
+	}
+	return true
+}
+
+// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
+// answer PING with +PONG. Returns the listener address, the self-signed cert (DER bytes), and a stop func.
+func startTLSPingServer(t *testing.T) (addr string, certDER []byte, stop func()) {
+	t.Helper()
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	require.NoError(t, err)
+
+	template := &x509.Certificate{
+		SerialNumber: big.NewInt(1),
+		Subject:      pkix.Name{CommonName: "localhost"},
+		NotBefore:    time.Now().Add(-time.Hour),
+		NotAfter:     time.Now().Add(time.Hour),
+		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		DNSNames:     []string{"localhost"},
+		IPAddresses:  []net.IP{net.ParseIP("127.0.0.1")},
+	}
+	der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) // template doubles as parent: self-signed
+	require.NoError(t, err)
+
+	tlsCert := tls.Certificate{
+		Certificate: [][]byte{der},
+		PrivateKey:  priv,
+	}
+
+	listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ // port 0: let the OS pick a free port
+		Certificates: []tls.Certificate{tlsCert},
+		MinVersion:   tls.VersionTLS12,
+	})
+	require.NoError(t, err)
+
+	var wg sync.WaitGroup
+	stopCh := make(chan struct{})
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			select { // non-blocking check for a stop request before each Accept
+			case <-stopCh:
+				return
+			default:
+			}
+			c, acceptErr := listener.Accept()
+			if acceptErr != nil {
+				return // Accept failed (stop closes the listener): end the accept loop
+			}
+			wg.Add(1)
+			go func(conn net.Conn) {
+				defer wg.Done()
+				defer conn.Close()
+				reader := bufio.NewReader(conn)
+				for {
+					_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // bound each read so idle handlers exit
+					if !drainRESPRequest(reader) {
+						return
+					}
+					_, _ = conn.Write([]byte("+PONG\r\n")) // every request gets +PONG, whatever the command
+				}
+			}(c)
+		}
+	}()
+
+	stop = func() {
+		close(stopCh)
+		_ = listener.Close() // makes the pending Accept return an error
+		wg.Wait()            // wait for the accept loop and all connection handlers
+	}
+	return listener.Addr().String(), der, stop // der is the raw DER encoding, not PEM
+}
+
+// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
+// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
+// Regression test for issue #133 (enableTLS not propagated to client).
+func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
+	addr, _, stop := startTLSPingServer(t) // cert bytes unused: verification is skipped below
+	defer stop()
+
+	pool, err := NewConnectionPool(&PoolConfig{
+		Address:        addr,
+		MaxConnections: 2,
+		ConnectTimeout: 2 * time.Second,
+		ReadTimeout:    1 * time.Second,
+		WriteTimeout:   1 * time.Second,
+		EnableTLS:      true,
+		TLSSkipVerify:  true,
+	})
+	require.NoError(t, err)
+	defer pool.Close()
+
+	conn, err := pool.Get(context.Background())
+	require.NoError(t, err)
+	require.NotNil(t, conn)
+	defer pool.Put(conn)
+
+	resp, err := conn.Do("PING") // round-trips one command over the TLS connection
+	require.NoError(t, err)
+	assert.Equal(t, "PONG", resp)
+}
+
+// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
+// TLSSkipVerify=false rejects a self-signed server cert.
+func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
+	addr, _, stop := startTLSPingServer(t)
+	defer stop()
+
+	pool, err := NewConnectionPool(&PoolConfig{
+		Address:        addr,
+		MaxConnections: 2,
+		ConnectTimeout: 2 * time.Second,
+		ReadTimeout:    1 * time.Second,
+		WriteTimeout:   1 * time.Second,
+		EnableTLS:      true,
+		TLSSkipVerify:  false, // verification on: the server's self-signed cert should be rejected
+	})
+	require.NoError(t, err)
+	defer pool.Close()
+
+	_, err = pool.Get(context.Background())
+	require.Error(t, err)
+	assert.Contains(t, strings.ToLower(err.Error()), "tls") // NOTE(review): depends on the error text mentioning "tls" — confirm across Go versions
+}
+
+// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
+// fails to handshake against a plain (non-TLS) listener.
+func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
+	plain, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer plain.Close()
+
+	go func() {
+		for {
+			c, acceptErr := plain.Accept()
+			if acceptErr != nil {
+				return // Accept failed (e.g. the deferred plain.Close ran): exit the goroutine
+			}
+			_ = c.Close() // drop the connection immediately; no TLS handshake is ever answered
+		}
+	}()
+
+	pool, err := NewConnectionPool(&PoolConfig{
+		Address:        plain.Addr().String(),
+		MaxConnections: 1,
+		ConnectTimeout: 1 * time.Second,
+		ReadTimeout:    1 * time.Second,
+		WriteTimeout:   1 * time.Second,
+		EnableTLS:      true,
+		TLSSkipVerify:  true, // so a failure cannot be blamed on certificate verification
+	})
+	require.NoError(t, err)
+	defer pool.Close()
+
+	_, err = pool.Get(context.Background())
+	require.Error(t, err)
+}
+
+// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
+// when EnableTLS=false (default).
+func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
+	mr := NewMiniredisServer(t) // helper defined outside this file; presumably an in-process Redis stand-in — confirm
+
+	pool, err := NewConnectionPool(&PoolConfig{
+		Address:        mr.GetAddr(),
+		MaxConnections: 1,
+		ConnectTimeout: 2 * time.Second,
+		ReadTimeout:    1 * time.Second,
+		WriteTimeout:   1 * time.Second,
+		EnableTLS:      false,
+	})
+	require.NoError(t, err)
+	defer pool.Close()
+
+	conn, err := pool.Get(context.Background())
+	require.NoError(t, err)
+	defer pool.Put(conn)
+
+	resp, err := conn.Do("PING")
+	require.NoError(t, err)
+	assert.Equal(t, "PONG", resp)
+}
diff --git a/universal_cache_singleton.go b/universal_cache_singleton.go
index 51aa644..0081b8b 100644
--- a/universal_cache_singleton.go
+++ b/universal_cache_singleton.go
@@ -210,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
 		RedisPrefix:   redisConfig.KeyPrefix,
 		PoolSize:      redisConfig.PoolSize,
 		EnableMetrics: true,
+		EnableTLS:     redisConfig.EnableTLS,
+		TLSSkipVerify: redisConfig.TLSSkipVerify,
 	}
 
 	// Use concrete type to avoid Yaegi reflection issues with interface assignment