Compare commits

...

3 Commits

Author SHA1 Message Date
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00
lukaszraczylo 684a990f59 fix: reduce yaegi CPU footprint + require auth on SSE/WebSocket bypass
minor-release

Behaviour changes (potentially breaking for operators relying on the prior
unauthenticated SSE bypass):

* SSE (`Accept: text/event-stream`) and WebSocket upgrade requests now
  return 401 when no authenticated session is present. Previously the
  bypass forwarded unconditionally, which let any caller reach the
  backend by setting the right header. Excluded URLs are unchanged.
  Operators relying on unauthenticated SSE/WS access must move the path
  into ExcludedURLs.

Performance fixes (target: long-running dashboards like Grafana / ArgoCD
where many panels poll concurrently while the page stays open):

* Stop honouring isTestMode() for the singleton-token-cleanup interval
  under yaegi (the Traefik plugin runtime). In production the plugin was
  running a 20 Hz no-op cleanup ticker because runtime.Compiler ==
  "yaegi" tripped the test-mode branch.
* processAuthorizedRequest now resolves ID-token claims at most once per
  request via SessionData.GetIDTokenClaims (already cached on the
  session) and reuses them for both groups/roles extraction and
  header-template rendering. Previously every authenticated request
  parsed the JWT twice.
* Added extractGroupsAndRolesFromClaims to drive groups/roles off
  pre-parsed claims; extractGroupsAndRoles still works for tests.
* Removed the unconditional session.MarkDirty() in the header-templates
  branch. Templates only mutate request headers, not session state, so
  the prior MarkDirty was re-encrypting and rewriting all session
  cookies on every authenticated request that used header templates.

Other:

* Added isWebSocketUpgrade (RFC 6455 handshake detection — Connection:
  Upgrade + Upgrade: websocket, tolerant of multi-token Connection
  headers and case).
* Renamed applySSEUserHeaders -> applyBypassUserHeaders; it now returns
  bool so the dispatcher can reject unauthenticated SSE/WS with 401.
* Added tests for SSE and WS bypass covering both the auth-rejection
  path and the authenticated forward path.
2026-05-02 03:12:20 +01:00
24 changed files with 912 additions and 226 deletions
+17
View File
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
| `rateLimit` | `100` | Requests/sec. Min `10`. |
| `logLevel` | `info` | `debug`, `info`, `error`. |
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
@@ -165,6 +166,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
otherwise a session minted by one instance can grant access through another.
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
### SSE and WebSocket endpoints
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
WebSocket upgrade. The middleware handles this automatically:
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
requests skip the OIDC redirect.
- They are **not** unauthenticated — a valid encrypted session cookie is
required, otherwise the request is rejected. The session must already exist
(i.e. the user logged in via a normal HTTP page first).
- `X-Forwarded-User` is forwarded from the session.
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
brief IdP outages.
No configuration needed — this is implicit behavior.
### HTTP 431 from backends
Either the ID token or the chunked OIDC cookies overflow your backend's header
+1 -1
View File
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("Failed to set authenticated: %v", err)
}
session.SetEmail("user@company.com")
session.SetUserIdentifier("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
+3 -3
View File
@@ -43,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
@@ -250,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetUserIdentifier(userIdentifier)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
@@ -290,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
+4 -4
View File
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Pre-populate session with old data
_ = session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-access-token-with-many-characters")
session.SetRefreshToken("old-refresh-token-with-many-characters")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Verify old data is cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
// Verify new data is set
s.Equal(csrfToken, session.GetCSRF())
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
session, err := sessionManager.GetSession(req)
s.Require().NoError(err)
_ = session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
session.mainSession.Values["redirect_count"] = 3
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
// Session should be cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
s.Empty(session.GetIDToken())
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
+4 -4
View File
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
session.SetCSRF(csrfToken)
session.SetNonce("test-nonce")
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("old-refresh-token")
session.SetIDToken("old-id-token")
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Now perform selective clearing (as done in the fix)
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// Set initial session data
session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-token")
session.SetCSRF("existing-csrf")
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
// NEW BEHAVIOR: Selective clearing
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
+28
View File
@@ -70,6 +70,33 @@ overwrite it).
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
dev). Otherwise leave it at default.
### Streaming Endpoints (SSE and WebSocket)
The middleware automatically bypasses the OIDC redirect for two request kinds
that browsers cannot follow a 302 on:
| Bypass | Triggered by |
|--------|--------------|
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
These requests do **not** require any explicit configuration — they are
handled implicitly. However, the bypass is **not** unauthenticated:
- A valid, encrypted session cookie is required. Requests without one are
rejected (the connection cannot proceed to the backend).
- The session cookie is sealed with `sessionEncryptionKey`, so the
`authenticated` flag cannot be forged.
- Validation is cookie-only — no JWK fetch / signature verification — so
streaming endpoints keep working when the OIDC provider is briefly
unavailable.
- The user identifier from the session is forwarded as `X-Forwarded-User`
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
For browser clients, the user must complete the normal OIDC flow on a
regular HTTP page first; the resulting session cookie is then reused on the
SSE / WebSocket connection.
---
## Security Options
@@ -113,6 +140,7 @@ strictAudienceValidation: true
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
+11 -1
View File
@@ -718,6 +718,11 @@ spec:
<td class="py-2 px-3">86400</td>
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
<td class="py-2 px-3">21600</td>
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
<td class="py-2 px-3">_oidc_raczylo_</td>
@@ -858,7 +863,12 @@ spec:
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable TLS for Redis connections</td>
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
</tr>
</tbody>
</table>
+3
View File
@@ -24,6 +24,7 @@ type Config struct {
Type BackendType
RedisAddr string
RedisPassword string
TLSServerName string
PoolSize int
RedisDB int
CleanupInterval time.Duration
@@ -34,6 +35,8 @@ type Config struct {
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
EnableTLS bool
TLSSkipVerify bool
}
// DefaultConfig returns a default configuration for in-memory caching
+3
View File
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
TLSServerName: config.TLSServerName,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
EnableTLS: config.EnableTLS,
TLSSkipVerify: config.TLSSkipVerify,
}
pool, err := NewConnectionPool(poolConfig)
+25 -3
View File
@@ -2,6 +2,7 @@ package backends
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@@ -31,6 +32,7 @@ type ConnectionPool struct {
type PoolConfig struct {
Address string
Password string
TLSServerName string // SNI server name; defaults to host(Address) when empty
DB int
MaxConnections int
ConnectTimeout time.Duration
@@ -39,6 +41,8 @@ type PoolConfig struct {
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
}
// NewConnectionPool creates a new connection pool
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
conn, err = p.createConnection(ctx)
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
var conn net.Conn
var err error
if p.config.EnableTLS {
serverName := p.config.TLSServerName
if serverName == "" {
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
serverName = host
}
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
MinVersion: tls.VersionTLS12,
}
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
} else {
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
+230
View File
@@ -0,0 +1,230 @@
package backends
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// drainRESPRequest consumes a single RESP request (array or inline) from r and
// returns true on success. Any read error returns false.
func drainRESPRequest(r *bufio.Reader) bool {
header, err := r.ReadString('\n')
if err != nil {
return false
}
if !strings.HasPrefix(header, "*") {
return true // inline command (single line) — already consumed
}
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
if err != nil || n <= 0 {
return false
}
for i := 0; i < n; i++ {
// Each bulk: "$len\r\n<bytes>\r\n"
if _, err := r.ReadString('\n'); err != nil {
return false
}
if _, err := r.ReadString('\n'); err != nil {
return false
}
}
return true
}
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
// answer PING with +PONG. Returns the listener address and a self-signed cert.
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
require.NoError(t, err)
tlsCert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS12,
})
require.NoError(t, err)
var wg sync.WaitGroup
stopCh := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
c, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
defer conn.Close()
reader := bufio.NewReader(conn)
for {
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if !drainRESPRequest(reader) {
return
}
_, _ = conn.Write([]byte("+PONG\r\n"))
}
}(c)
}
}()
stop = func() {
close(stopCh)
_ = listener.Close()
wg.Wait()
}
return listener.Addr().String(), der, stop
}
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
// Regression test for issue #133 (enableTLS not propagated to client).
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
// TLSSkipVerify=false rejects a self-signed server cert.
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: false,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "tls")
}
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
// fails to handshake against a plain (non-TLS) listener.
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
plain, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer plain.Close()
go func() {
for {
c, acceptErr := plain.Accept()
if acceptErr != nil {
return
}
_ = c.Close()
}
}()
pool, err := NewConnectionPool(&PoolConfig{
Address: plain.Addr().String(),
MaxConnections: 1,
ConnectTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
}
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
// when EnableTLS=false (default).
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: false,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
+135
View File
@@ -0,0 +1,135 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
// the fix for issue #132: token refresh path hardcoded the "email" claim and
// ignored the configured userIdentifierClaim. Keycloak users without an email
// claim (using sub or another identifier) were being kicked out on refresh
// even though their initial login worked.
//
// The callback path (auth_flow.go) already honored userIdentifierClaim with
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
// after PR #100 (commit a316a98).
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
tests := []struct {
claims map[string]any
name string
userIdentifierClaim string
expectedIdentifier string
expectSuccess bool
}{
{
name: "sub claim configured, only sub present (Keycloak no-email case)",
userIdentifierClaim: "sub",
claims: map[string]any{
"sub": "user-uuid-keycloak-12345",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user-uuid-keycloak-12345",
},
{
name: "preferred_username configured, claim present",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "user-uuid-12345",
"preferred_username": "alice",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "alice",
},
{
name: "configured claim missing, falls back to sub",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "fallback-sub-id",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "fallback-sub-id",
},
{
name: "email default, email present (backward compatibility)",
userIdentifierClaim: "email",
claims: map[string]any{
"sub": "user-uuid-12345",
"email": "user@example.com",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user@example.com",
},
{
name: "email default, no email and no sub - refresh fails",
userIdentifierClaim: "email",
claims: map[string]any{
"exp": float64(9999999999),
},
expectSuccess: false,
expectedIdentifier: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionManager, err := NewSessionManager(
"test-encryption-key-32-bytes-long!!",
false,
"",
"",
0,
NewLogger("error"),
)
if err != nil {
t.Fatalf("session manager: %v", err)
}
defer sessionManager.Shutdown()
capturedClaims := tt.claims
tOidc := &TraefikOidc{
logger: NewLogger("error"),
userIdentifierClaim: tt.userIdentifierClaim,
sessionManager: sessionManager,
tokenExchanger: &EnhancedMockTokenExchanger{
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
IDToken: "new-id-token-jwt",
ExpiresIn: 3600,
},
},
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
extractClaimsFunc: func(token string) (map[string]any, error) {
return capturedClaims, nil
},
}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("get session: %v", err)
}
defer session.returnToPoolSafely()
session.SetRefreshToken("initial-refresh-token")
refreshed := tOidc.refreshToken(rw, req, session)
if refreshed != tt.expectSuccess {
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
}
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
}
})
}
}
+199 -47
View File
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
}
}
// TestServeHTTP_EventStream tests the event-stream bypass functionality
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
// handshake must skip the OIDC redirect dance (clients can't follow it
// mid-stream) but it must STILL require an authenticated session, otherwise
// any caller could reach the backend by setting Accept: text/event-stream.
func TestServeHTTP_EventStream(t *testing.T) {
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
}
})
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
// Build an authenticated session and inject its cookies onto req.
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated SSE request to be forwarded to backend")
}
if forwardedUser != "user@example.com" {
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
}
})
}
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
// handshake bypasses the OIDC redirect (clients can't follow it) but the
// session must already be authenticated, otherwise the backend is exposed
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
oidc.ServeHTTP(rw, req)
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rw := httptest.NewRecorder()
if !nextCalled {
t.Error("expected event-stream request to bypass OIDC")
}
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
}
})
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
}))
req := httptest.NewRequest("GET", "/ws", nil)
// Mixed-case + multi-token Connection header to exercise parsing.
req.Header.Set("Connection", "keep-alive, Upgrade")
req.Header.Set("Upgrade", "WebSocket")
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("ws-user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
}
if forwardedUser != "ws-user@example.com" {
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
}
})
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
// Sanity: requests without Upgrade headers must NOT hit the WS
// bypass branch (otherwise the new code path could short-circuit
// normal authentication).
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("backend must not be called for unauthenticated plain HTTP")
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "keep-alive")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code == http.StatusOK {
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
}
})
}
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "",
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetEmail(session.email)
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
// MockSessionData is a test implementation of SessionData interface
type MockSessionData struct {
email string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
userIdentifier string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetEmail() string { return m.email }
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string { return m.idToken }
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
func (m *MockSessionData) SetEmail(email string) { m.email = email }
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
}
// Set up session data
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -771,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
@@ -852,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -899,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
@@ -950,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
@@ -1013,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
+15 -15
View File
@@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
@@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
@@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
},
expectedPath: "/original/path",
},
@@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
@@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
+150 -78
View File
@@ -14,21 +14,40 @@ import (
)
// bypassReason describes why a request is being forwarded without OIDC auth.
// It is only used for logging and to decide whether extra SSE-specific work
// It is only used for logging and to decide whether extra side-effects
// (propagating the user header from an existing session) should run.
const (
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonWebSocket = "websocket"
)
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
// (RFC 6455). The middleware can only see the handshake; once Traefik
// completes the upgrade it forwards frames directly, so we never re-process
// per-frame traffic. We bypass auth on the handshake the same way we do for
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
func isWebSocketUpgrade(req *http.Request) bool {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
return false
}
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
return true
}
}
return false
}
// shouldBypassAuth decides whether a request must skip OIDC authentication
// entirely. It returns (true, reason) when either the request path matches a
// configured excluded URL or the Accept header asks for a text/event-stream
// response (SSE). The reason lets ServeHTTP apply any side-effects that are
// unique to the bypass kind (e.g. propagating user headers for SSE).
// configured excluded URL, the Accept header asks for a text/event-stream
// response (SSE), or the request is a WebSocket upgrade handshake. The
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
// kind (e.g. propagating user headers).
//
// This must be called BEFORE waiting on t.initComplete so excluded and SSE
// traffic is never blocked by a slow/broken provider.
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
// WebSocket traffic is never blocked by a slow/broken provider.
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if t.determineExcludedURL(req.URL.Path) {
return true, bypassReasonExcluded
@@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true, bypassReasonSSE
}
if isWebSocketUpgrade(req) {
return true, bypassReasonWebSocket
}
return false, ""
}
// applySSEUserHeaders attempts to copy the authenticated user's identity from
// an existing session onto the outgoing SSE request so downstream services
// can still see who the user is. Failures are logged (not silenced) because
// they indicate either a corrupt cookie or a misconfigured session manager
// and are useful for debugging, but they never block the bypass itself.
func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) {
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
// requests and, on success, copies the authenticated user's identity onto
// the outgoing request so downstream services can see who the user is.
//
// Returns true when the request carries a valid authenticated session and
// the bypass should proceed. Returns false when no usable session is
// present; callers must then reject the request (typically with 401) to
// prevent unauthenticated traffic from reaching the backend just by setting
// `Accept: text/event-stream` or sending a WebSocket upgrade.
//
// The check is cookie-only: the session cookie is sealed by our encryption
// key, so the authenticated flag cannot be forged. We do NOT run full token
// signature verification here so that SSE/WS keeps working when the OIDC
// provider is briefly unavailable for JWK fetches.
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
if t.sessionManager == nil {
return
return false
}
session, err := t.sessionManager.GetSession(req)
if err != nil {
// Intentionally not fatal: SSE requests bypass auth, we just lose the
// forwarded-user header for this request.
t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err)
return
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
return false
}
defer session.returnToPoolSafely()
email := session.GetEmail()
if email == "" {
return
if !session.GetAuthenticated() {
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
return false
}
req.Header.Set("X-Forwarded-User", email)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", email)
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
return false
}
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
req.Header.Set("X-Forwarded-User", userIdentifier)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", userIdentifier)
}
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
return true
}
// ServeHTTP implements the main middleware logic for processing HTTP requests.
@@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.firstRequestMutex.Unlock()
}
// Evaluate auth-bypass once, before waiting for initialization. Excluded URLs
// and SSE requests must not block on provider init. For SSE we additionally
// attempt to forward the user identity from an existing session (best
// effort) so downstream handlers still see X-Forwarded-User.
// Evaluate auth-bypass once, before waiting for initialization. Excluded
// URLs, SSE and WebSocket upgrade requests must not block on provider
// init. For SSE/WebSocket we ALSO require an authenticated session
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
// of in-flight streams can't follow an OIDC redirect, so forwarding
// unauthenticated traffic would silently expose the backend.
if bypass, reason := t.shouldBypassAuth(req); bypass {
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
if reason == bypassReasonSSE {
t.applySSEUserHeaders(req)
switch reason {
case bypassReasonExcluded:
// Operator-declared excluded URLs forward unconditionally.
t.next.ServeHTTP(rw, req)
case bypassReasonSSE, bypassReasonWebSocket:
// Skip the OIDC redirect dance (clients can't follow it
// mid-stream) but still require an authenticated session.
// Otherwise an unauthenticated client could hit the backend
// just by setting Accept: text/event-stream or sending a
// WebSocket upgrade.
if !t.applyBypassUserHeaders(req, reason) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
t.next.ServeHTTP(rw, req)
default:
t.next.ServeHTTP(rw, req)
}
t.next.ServeHTTP(rw, req)
return
}
@@ -237,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
userIdentifier := session.GetUserIdentifier()
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
@@ -309,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
userIdentifier = session.GetUserIdentifier()
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
@@ -359,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Info("No email found in session during final processing, initiating re-auth")
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
// Reset redirect count to prevent loops when session is invalid
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
@@ -374,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
// Clear the session and redirect to login
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
@@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
// Reset redirect count to prevent loops when token is missing
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token, so concurrent dashboard
// panel requests on the same session don't repeatedly base64-decode and
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
// that hosts Traefik plugins). idClaims is reused below by the
// header-templates branch.
idToken := session.GetIDToken()
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
// Choose which claims drive groups/roles extraction. Prefer the ID
// token (cached) and fall back to the access token if there is no ID
// token in the session — matching the prior behavior for opaque
// ID-token providers.
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if accessToken := session.GetAccessToken(); accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
var groups, roles []string
if groupClaimsErr == nil && groupClaims != nil {
var err error
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
// Initialize empty slices
var groups, roles []string
if tokenForClaims != "" {
var err error
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Reset redirect count to prevent loops when claim extraction fails
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
} else if err == nil {
if err == nil {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -429,54 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
req.Header.Set("X-Forwarded-User", userIdentifier)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-User", userIdentifier)
if idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
}
if len(t.headerTemplates) > 0 {
// Reuse claims parsed earlier in this request if the ID token has not
// changed. Saves an unnecessary JWT parse on every authenticated
// request that uses headerTemplates.
claims, err := session.GetIDTokenClaims(t.extractClaimsFunc)
if err != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
if idClaimsErr != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
} else {
// idClaims may be nil when no ID token is present; templates
// referencing .Claims.* will simply produce empty values, which
// matches the prior behavior.
templateData := map[string]interface{}{
"AccessToken": session.GetAccessToken(),
"IDToken": session.GetIDToken(),
"IDToken": idToken,
"RefreshToken": session.GetRefreshToken(),
"Claims": claims,
"Claims": idClaims,
}
for headerName, tmpl := range t.headerTemplates {
var buf bytes.Buffer
if err := tmpl.Execute(&buf, templateData); err != nil {
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
continue
}
headerValue := buf.String()
req.Header.Set(headerName, headerValue)
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
}
session.MarkDirty()
t.logger.Debugf("Session marked dirty after templated header processing.")
// NOTE: templates only mutate request headers (not session state),
// so we deliberately do NOT MarkDirty / Save here. Previously every
// authenticated request with header templates re-encrypted and
// rewrote all session cookies, which was a measurable CPU and
// Set-Cookie tax on dashboards that poll many panels per second.
}
}
@@ -515,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
t.next.ServeHTTP(rw, req)
}
+7 -7
View File
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create authenticated session
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
session.SetIDToken("dummy-token")
session.Save(req, httptest.NewRecorder())
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create session with forbidden domain
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@forbidden.com")
session.SetUserIdentifier("user@forbidden.com")
session.SetAuthenticated(true)
// Save and inject cookies
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
// Create session with opaque token
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
session.SetAuthenticated(true)
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("") // No email
session.SetUserIdentifier("") // No email
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("") // No ID token
session.SetAccessToken("") // No access token
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
testEmail := "user@example.com"
session.SetEmail(testEmail)
session.SetUserIdentifier(testEmail)
session.SetIDToken("dummy-id-token")
rw := httptest.NewRecorder()
+2 -2
View File
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
// Simulate successful Azure authentication
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Azure may use opaque access tokens
session.SetAccessToken("opaque-azure-access-token")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
require.NoError(t, err)
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
assert.Equal(t, "user@example.com", session2.GetEmail())
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
+2 -2
View File
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
// Set up the attacker's session with malicious data
attackerSession.SetAuthenticated(true)
attackerSession.SetEmail("attacker@evil.com")
attackerSession.SetUserIdentifier("attacker@evil.com")
attackerSession.SetIDToken(ValidIDToken)
attackerSession.SetAccessToken(ValidAccessToken)
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
}
// Get the email from the session
email := session.GetEmail()
email := session.GetUserIdentifier()
w.Header().Set("X-User-Email", email)
w.WriteHeader(http.StatusOK)
})
+26 -26
View File
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
A string `json:"a,omitempty"`
R string `json:"r,omitempty"`
I string `json:"i,omitempty"`
E string `json:"e,omitempty"`
Ui string `json:"ui,omitempty"`
Cs string `json:"cs,omitempty"`
N string `json:"n,omitempty"`
Cv string `json:"cv,omitempty"`
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
// All other mainSession.Values keys are stored in the X (extra) field.
var knownSessionKeys = map[string]bool{
"access_token": true,
"refresh_token": true,
"id_token": true,
"email": true,
"authenticated": true,
"access_token": true,
"refresh_token": true,
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
// Populate legacy session values from combined payload
sessionData.mainSession.Values["email"] = payload.E
sessionData.mainSession.Values["user_identifier"] = payload.Ui
sessionData.mainSession.Values["authenticated"] = payload.Au
sessionData.mainSession.Values["csrf"] = payload.Cs
sessionData.mainSession.Values["nonce"] = payload.N
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
A: sd.getAccessTokenUnsafe(),
R: sd.getRefreshTokenUnsafe(),
I: sd.getIDTokenUnsafe(),
E: sd.getEmailUnsafe(),
Ui: sd.getUserIdentifierUnsafe(),
Au: sd.getAuthenticatedUnsafe(),
Cs: sd.getCSRFUnsafe(),
N: sd.getNonceUnsafe(),
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
}
}
// GetEmail retrieves the authenticated user's email address.
// The email is extracted from ID token claims and used for
// authorization decisions and header injection.
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
// upn, preferred_username, etc.). The value is used for authorization
// decisions and header injection.
// Returns:
// - The user's email address string, or an empty string if not set.
func (sd *SessionData) GetEmail() string {
// - The user identifier string, or an empty string if not set.
func (sd *SessionData) GetUserIdentifier() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
email, _ := sd.mainSession.Values["email"].(string)
return email
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// SetEmail stores the authenticated user's email address.
// The email is typically extracted from the 'email' claim in the ID token.
// SetUserIdentifier stores the authenticated user's identifier value.
// Parameters:
// - email: The user's email address to store.
func (sd *SessionData) SetEmail(email string) {
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentVal, _ := sd.mainSession.Values["email"].(string)
if currentVal != email {
sd.mainSession.Values["email"] = email
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
if currentVal != userIdentifier {
sd.mainSession.Values["user_identifier"] = userIdentifier
sd.dirty = true
}
}
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
return result.Token
}
// getEmailUnsafe retrieves the email without acquiring locks.
func (sd *SessionData) getEmailUnsafe() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
func (sd *SessionData) getUserIdentifierUnsafe() string {
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
+6 -7
View File
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
s.False(session.IsDirty())
}
// TestSessionData_SetEmail tests email setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
session, err := s.sessionManager.GetSession(req)
s.Require().NoError(err)
defer session.returnToPoolSafely()
// Set email
session.SetEmail("test@example.com")
s.Equal("test@example.com", session.GetEmail())
session.SetUserIdentifier("test@example.com")
s.Equal("test@example.com", session.GetUserIdentifier())
s.True(session.IsDirty())
}
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
// Set some data
err = session.SetAuthenticated(true)
s.Require().NoError(err)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetCSRF("csrf-token")
// Clear session
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
defer session.returnToPoolSafely()
// Modify session
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
s.True(session.IsDirty())
// Save session
+6 -6
View File
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Set up initial session state (what user has when first logging in)
session1.SetAuthenticated(true)
session1.SetEmail(originalUserData["email"].(string))
session1.SetUserIdentifier(originalUserData["email"].(string))
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Simulate what happens when middleware detects expired tokens
// It should preserve session state while attempting token refresh
originalAuth := session2.GetAuthenticated()
originalEmail := session2.GetEmail()
originalEmail := session2.GetUserIdentifier()
// Reconstruct user data from individual stored keys
originalUserDataStored := make(map[string]interface{})
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Verify all session data is still intact after token refresh
postRefreshAuth := session2.GetAuthenticated()
postRefreshEmail := session2.GetEmail()
postRefreshEmail := session2.GetUserIdentifier()
userDataPresent := true
for k := range originalUserData {
if session2.mainSession.Values["user_data_"+k] == nil {
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
// Set up session with specific creation time
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
// Create tokens with specific expiry
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
// Set up session with data that should be preserved or removed
session.SetAuthenticated(true)
session.SetEmail("cleanup@example.com")
session.SetUserIdentifier("cleanup@example.com")
session.mainSession.Values["user_data"] = "Test User|user-123"
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
if scenario.shouldCleanup {
if sessionTooOld {
session.SetAuthenticated(false)
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
for key := range session.mainSession.Values {
+1 -1
View File
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
}
session.SetAuthenticated(true)
session.SetEmail(tf.fixtures.UserEmail)
session.SetUserIdentifier(tf.fixtures.UserEmail)
session.SetAccessToken(tf.fixtures.AccessToken)
session.SetRefreshToken(tf.fixtures.RefreshToken)
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
+32 -19
View File
@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"
)
@@ -433,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetRefreshToken("")
session.SetAccessToken("")
session.SetIDToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens as well to prevent any replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -475,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
return false
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
return false
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
return false
}
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
}
session.SetEmail(email)
session.SetUserIdentifier(userIdentifier)
// Get token expiry information for logging
var expiryTime time.Time
@@ -506,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetAccessToken("")
session.SetIDToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
return false
}
@@ -1193,9 +1200,14 @@ func (t *TraefikOidc) startTokenCleanup() {
sessionManager := t.sessionManager
logger := t.logger
// Only use the fast cleanup interval when actually running under `go test`.
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
// jwkCache.Cleanup both being no-ops there.
cleanupInterval := 1 * time.Minute
if isTestMode() {
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
if isTestMode() && runtime.Compiler != "yaegi" {
cleanupInterval = 50 * time.Millisecond
}
// Create cleanup function
@@ -1237,25 +1249,27 @@ func (t *TraefikOidc) startTokenCleanup() {
}
// extractGroupsAndRoles extracts group and role information from token claims.
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
// Parameters:
// - idToken: The ID token containing claims to extract.
// It parses the configured group/role claims from the supplied ID token.
//
// Returns:
// - groups: Array of group names from the 'groups' claim.
// - roles: Array of role names from the 'roles' claim.
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
// but not arrays of strings.
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
// to avoid re-parsing the JWT.
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
return t.extractGroupsAndRolesFromClaims(claims)
}
// extractGroupsAndRolesFromClaims extracts group and role information from
// already-parsed claims. Hot path: callers that have a cached claims map (such
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
// base64+JSON decode of the JWT on every authenticated request.
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
var groups []string
var roles []string
// Extract groups using configurable claim name (defaults to "groups")
if groupsClaim, exists := claims[t.groupClaimName]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
@@ -1271,7 +1285,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
}
}
// Extract roles using configurable claim name (defaults to "roles")
if rolesClaim, exists := claims[t.roleClaimName]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
+2
View File
@@ -210,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
RedisPrefix: redisConfig.KeyPrefix,
PoolSize: redisConfig.PoolSize,
EnableMetrics: true,
EnableTLS: redisConfig.EnableTLS,
TLSSkipVerify: redisConfig.TLSSkipVerify,
}
// Use concrete type to avoid Yaegi reflection issues with interface assignment