mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68c150eba4 | |||
| 9cbca4c4fb | |||
| 684a990f59 |
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
|
||||
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
|
||||
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
|
||||
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
|
||||
| `rateLimit` | `100` | Requests/sec. Min `10`. |
|
||||
| `logLevel` | `info` | `debug`, `info`, `error`. |
|
||||
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
|
||||
@@ -165,6 +166,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
|
||||
otherwise a session minted by one instance can grant access through another.
|
||||
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
|
||||
|
||||
### SSE and WebSocket endpoints
|
||||
|
||||
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
|
||||
WebSocket upgrade. The middleware handles this automatically:
|
||||
|
||||
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
|
||||
requests skip the OIDC redirect.
|
||||
- They are **not** unauthenticated — a valid encrypted session cookie is
|
||||
required, otherwise the request is rejected. The session must already exist
|
||||
(i.e. the user logged in via a normal HTTP page first).
|
||||
- `X-Forwarded-User` is forwarded from the session.
|
||||
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
|
||||
brief IdP outages.
|
||||
|
||||
No configuration needed — this is implicit behavior.
|
||||
|
||||
### HTTP 431 from backends
|
||||
|
||||
Either the ID token or the chunked OIDC cookies overflow your backend's header
|
||||
|
||||
+1
-1
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+3
-3
@@ -43,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -250,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -290,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
@@ -70,6 +70,33 @@ overwrite it).
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
@@ -113,6 +140,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
|
||||
+11
-1
@@ -718,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -858,7 +863,12 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+3
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
|
||||
+25
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+199
-47
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_EventStream tests the event-stream bypass functionality
|
||||
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
|
||||
// handshake must skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but it must STILL require an authenticated session, otherwise
|
||||
// any caller could reach the backend by setting Accept: text/event-stream.
|
||||
func TestServeHTTP_EventStream(t *testing.T) {
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
|
||||
}
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Build an authenticated session and inject its cookies onto req.
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated SSE request to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
|
||||
// handshake bypasses the OIDC redirect (clients can't follow it) but the
|
||||
// session must already be authenticated, otherwise the backend is exposed
|
||||
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
|
||||
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected event-stream request to bypass OIDC")
|
||||
}
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
// Mixed-case + multi-token Connection header to exercise parsing.
|
||||
req.Header.Set("Connection", "keep-alive, Upgrade")
|
||||
req.Header.Set("Upgrade", "WebSocket")
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "ws-user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
|
||||
// Sanity: requests without Upgrade headers must NOT hit the WS
|
||||
// bypass branch (otherwise the new code path could short-circuit
|
||||
// normal authentication).
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("backend must not be called for unauthenticated plain HTTP")
|
||||
}))
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
|
||||
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -771,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
@@ -852,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -899,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
@@ -950,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
@@ -1013,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
|
||||
+15
-15
@@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
|
||||
+150
-78
@@ -14,21 +14,40 @@ import (
|
||||
)
|
||||
|
||||
// bypassReason describes why a request is being forwarded without OIDC auth.
|
||||
// It is only used for logging and to decide whether extra SSE-specific work
|
||||
// It is only used for logging and to decide whether extra side-effects
|
||||
// (propagating the user header from an existing session) should run.
|
||||
const (
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
|
||||
// (RFC 6455). The middleware can only see the handshake; once Traefik
|
||||
// completes the upgrade it forwards frames directly, so we never re-process
|
||||
// per-frame traffic. We bypass auth on the handshake the same way we do for
|
||||
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldBypassAuth decides whether a request must skip OIDC authentication
|
||||
// entirely. It returns (true, reason) when either the request path matches a
|
||||
// configured excluded URL or the Accept header asks for a text/event-stream
|
||||
// response (SSE). The reason lets ServeHTTP apply any side-effects that are
|
||||
// unique to the bypass kind (e.g. propagating user headers for SSE).
|
||||
// configured excluded URL, the Accept header asks for a text/event-stream
|
||||
// response (SSE), or the request is a WebSocket upgrade handshake. The
|
||||
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
|
||||
// kind (e.g. propagating user headers).
|
||||
//
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded and SSE
|
||||
// traffic is never blocked by a slow/broken provider.
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
|
||||
// WebSocket traffic is never blocked by a slow/broken provider.
|
||||
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
return true, bypassReasonExcluded
|
||||
@@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
|
||||
return true, bypassReasonSSE
|
||||
}
|
||||
if isWebSocketUpgrade(req) {
|
||||
return true, bypassReasonWebSocket
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// applySSEUserHeaders attempts to copy the authenticated user's identity from
|
||||
// an existing session onto the outgoing SSE request so downstream services
|
||||
// can still see who the user is. Failures are logged (not silenced) because
|
||||
// they indicate either a corrupt cookie or a misconfigured session manager
|
||||
// and are useful for debugging, but they never block the bypass itself.
|
||||
func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) {
|
||||
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
|
||||
// requests and, on success, copies the authenticated user's identity onto
|
||||
// the outgoing request so downstream services can see who the user is.
|
||||
//
|
||||
// Returns true when the request carries a valid authenticated session and
|
||||
// the bypass should proceed. Returns false when no usable session is
|
||||
// present; callers must then reject the request (typically with 401) to
|
||||
// prevent unauthenticated traffic from reaching the backend just by setting
|
||||
// `Accept: text/event-stream` or sending a WebSocket upgrade.
|
||||
//
|
||||
// The check is cookie-only: the session cookie is sealed by our encryption
|
||||
// key, so the authenticated flag cannot be forged. We do NOT run full token
|
||||
// signature verification here so that SSE/WS keeps working when the OIDC
|
||||
// provider is briefly unavailable for JWK fetches.
|
||||
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
|
||||
if t.sessionManager == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
// Intentionally not fatal: SSE requests bypass auth, we just lose the
|
||||
// forwarded-user header for this request.
|
||||
t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err)
|
||||
return
|
||||
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
|
||||
return false
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
return
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
|
||||
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
||||
@@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded URLs
|
||||
// and SSE requests must not block on provider init. For SSE we additionally
|
||||
// attempt to forward the user identity from an existing session (best
|
||||
// effort) so downstream handlers still see X-Forwarded-User.
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
// URLs, SSE and WebSocket upgrade requests must not block on provider
|
||||
// init. For SSE/WebSocket we ALSO require an authenticated session
|
||||
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
|
||||
// of in-flight streams can't follow an OIDC redirect, so forwarding
|
||||
// unauthenticated traffic would silently expose the backend.
|
||||
if bypass, reason := t.shouldBypassAuth(req); bypass {
|
||||
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
|
||||
if reason == bypassReasonSSE {
|
||||
t.applySSEUserHeaders(req)
|
||||
switch reason {
|
||||
case bypassReasonExcluded:
|
||||
// Operator-declared excluded URLs forward unconditionally.
|
||||
t.next.ServeHTTP(rw, req)
|
||||
case bypassReasonSSE, bypassReasonWebSocket:
|
||||
// Skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but still require an authenticated session.
|
||||
// Otherwise an unauthenticated client could hit the backend
|
||||
// just by setting Accept: text/event-stream or sending a
|
||||
// WebSocket upgrade.
|
||||
if !t.applyBypassUserHeaders(req, reason) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
default:
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -309,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -359,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -374,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
@@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
// Reset redirect count to prevent loops when token is missing
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token, so concurrent dashboard
|
||||
// panel requests on the same session don't repeatedly base64-decode and
|
||||
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
|
||||
// that hosts Traefik plugins). idClaims is reused below by the
|
||||
// header-templates branch.
|
||||
idToken := session.GetIDToken()
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
// Choose which claims drive groups/roles extraction. Prefer the ID
|
||||
// token (cached) and fall back to the access token if there is no ID
|
||||
// token in the session — matching the prior behavior for opaque
|
||||
// ID-token providers.
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
var groups, roles []string
|
||||
|
||||
if groupClaimsErr == nil && groupClaims != nil {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize empty slices
|
||||
var groups, roles []string
|
||||
|
||||
if tokenForClaims != "" {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Reset redirect count to prevent loops when claim extraction fails
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
} else if err == nil {
|
||||
if err == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
@@ -429,54 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
if idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
// Reuse claims parsed earlier in this request if the ID token has not
|
||||
// changed. Saves an unnecessary JWT parse on every authenticated
|
||||
// request that uses headerTemplates.
|
||||
claims, err := session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
||||
if idClaimsErr != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
|
||||
} else {
|
||||
// idClaims may be nil when no ID token is present; templates
|
||||
// referencing .Claims.* will simply produce empty values, which
|
||||
// matches the prior behavior.
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"IDToken": idToken,
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
"Claims": idClaims,
|
||||
}
|
||||
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
|
||||
req.Header.Set(headerName, headerValue)
|
||||
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
// NOTE: templates only mutate request headers (not session state),
|
||||
// so we deliberately do NOT MarkDirty / Save here. Previously every
|
||||
// authenticated request with header templates re-encrypted and
|
||||
// rewrote all session cookies, which was a measurable CPU and
|
||||
// Set-Cookie tax on dashboards that poll many panels per second.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
+26
-26
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Ui string `json:"ui,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["user_identifier"] = payload.Ui
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Ui: sd.getUserIdentifierUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address.
|
||||
// The email is extracted from ID token claims and used for
|
||||
// authorization decisions and header injection.
|
||||
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
|
||||
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
|
||||
// upn, preferred_username, etc.). The value is used for authorization
|
||||
// decisions and header injection.
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
// - The user identifier string, or an empty string if not set.
|
||||
func (sd *SessionData) GetUserIdentifier() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// SetEmail stores the authenticated user's email address.
|
||||
// The email is typically extracted from the 'email' claim in the ID token.
|
||||
// SetUserIdentifier stores the authenticated user's identifier value.
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
|
||||
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
if currentVal != userIdentifier {
|
||||
sd.mainSession.Values["user_identifier"] = userIdentifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
|
||||
func (sd *SessionData) getUserIdentifierUnsafe() string {
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
|
||||
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
|
||||
s.False(session.IsDirty())
|
||||
}
|
||||
|
||||
// TestSessionData_SetEmail tests email setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
|
||||
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
session, err := s.sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Set email
|
||||
session.SetEmail("test@example.com")
|
||||
s.Equal("test@example.com", session.GetEmail())
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.Equal("test@example.com", session.GetUserIdentifier())
|
||||
s.True(session.IsDirty())
|
||||
}
|
||||
|
||||
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
|
||||
// Set some data
|
||||
err = session.SetAuthenticated(true)
|
||||
s.Require().NoError(err)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetCSRF("csrf-token")
|
||||
|
||||
// Clear session
|
||||
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Modify session
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.True(session.IsDirty())
|
||||
|
||||
// Save session
|
||||
|
||||
+6
-6
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Set up initial session state (what user has when first logging in)
|
||||
session1.SetAuthenticated(true)
|
||||
session1.SetEmail(originalUserData["email"].(string))
|
||||
session1.SetUserIdentifier(originalUserData["email"].(string))
|
||||
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
||||
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
||||
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
||||
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
// Simulate what happens when middleware detects expired tokens
|
||||
// It should preserve session state while attempting token refresh
|
||||
originalAuth := session2.GetAuthenticated()
|
||||
originalEmail := session2.GetEmail()
|
||||
originalEmail := session2.GetUserIdentifier()
|
||||
|
||||
// Reconstruct user data from individual stored keys
|
||||
originalUserDataStored := make(map[string]interface{})
|
||||
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Verify all session data is still intact after token refresh
|
||||
postRefreshAuth := session2.GetAuthenticated()
|
||||
postRefreshEmail := session2.GetEmail()
|
||||
postRefreshEmail := session2.GetUserIdentifier()
|
||||
userDataPresent := true
|
||||
for k := range originalUserData {
|
||||
if session2.mainSession.Values["user_data_"+k] == nil {
|
||||
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with specific creation time
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
||||
|
||||
// Create tokens with specific expiry
|
||||
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with data that should be preserved or removed
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("cleanup@example.com")
|
||||
session.SetUserIdentifier("cleanup@example.com")
|
||||
|
||||
session.mainSession.Values["user_data"] = "Test User|user-123"
|
||||
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
||||
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
if scenario.shouldCleanup {
|
||||
if sessionTooOld {
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
for key := range session.mainSession.Values {
|
||||
|
||||
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(tf.fixtures.UserEmail)
|
||||
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
||||
session.SetAccessToken(tf.fixtures.AccessToken)
|
||||
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
||||
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
||||
|
||||
+32
-19
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -433,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens as well to prevent any replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -475,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
|
||||
// Get token expiry information for logging
|
||||
var expiryTime time.Time
|
||||
@@ -506,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1193,9 +1200,14 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
sessionManager := t.sessionManager
|
||||
logger := t.logger
|
||||
|
||||
// Only use the fast cleanup interval when actually running under `go test`.
|
||||
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
|
||||
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
|
||||
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
|
||||
// jwkCache.Cleanup both being no-ops there.
|
||||
cleanupInterval := 1 * time.Minute
|
||||
if isTestMode() {
|
||||
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
|
||||
if isTestMode() && runtime.Compiler != "yaegi" {
|
||||
cleanupInterval = 50 * time.Millisecond
|
||||
}
|
||||
|
||||
// Create cleanup function
|
||||
@@ -1237,25 +1249,27 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts group and role information from token claims.
|
||||
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
|
||||
// Parameters:
|
||||
// - idToken: The ID token containing claims to extract.
|
||||
// It parses the configured group/role claims from the supplied ID token.
|
||||
//
|
||||
// Returns:
|
||||
// - groups: Array of group names from the 'groups' claim.
|
||||
// - roles: Array of role names from the 'roles' claim.
|
||||
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
|
||||
// but not arrays of strings.
|
||||
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
|
||||
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
|
||||
// to avoid re-parsing the JWT.
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
return t.extractGroupsAndRolesFromClaims(claims)
|
||||
}
|
||||
|
||||
// extractGroupsAndRolesFromClaims extracts group and role information from
|
||||
// already-parsed claims. Hot path: callers that have a cached claims map (such
|
||||
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
|
||||
// base64+JSON decode of the JWT on every authenticated request.
|
||||
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups using configurable claim name (defaults to "groups")
|
||||
if groupsClaim, exists := claims[t.groupClaimName]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
@@ -1271,7 +1285,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles using configurable claim name (defaults to "roles")
|
||||
if rolesClaim, exists := claims[t.roleClaimName]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
|
||||
@@ -210,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
EnableTLS: redisConfig.EnableTLS,
|
||||
TLSSkipVerify: redisConfig.TLSSkipVerify,
|
||||
}
|
||||
|
||||
// Use concrete type to avoid Yaegi reflection issues with interface assignment
|
||||
|
||||
Reference in New Issue
Block a user