mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68c150eba4 | |||
| 9cbca4c4fb | |||
| 684a990f59 | |||
| 1b6c8616fd | |||
| 4d28fa01ab |
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
|
||||
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
|
||||
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
|
||||
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
|
||||
| `rateLimit` | `100` | Requests/sec. Min `10`. |
|
||||
| `logLevel` | `info` | `debug`, `info`, `error`. |
|
||||
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
|
||||
@@ -165,6 +166,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
|
||||
otherwise a session minted by one instance can grant access through another.
|
||||
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
|
||||
|
||||
### SSE and WebSocket endpoints
|
||||
|
||||
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
|
||||
WebSocket upgrade. The middleware handles this automatically:
|
||||
|
||||
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
|
||||
requests skip the OIDC redirect.
|
||||
- They are **not** unauthenticated — a valid encrypted session cookie is
|
||||
required, otherwise the request is rejected. The session must already exist
|
||||
(i.e. the user logged in via a normal HTTP page first).
|
||||
- `X-Forwarded-User` is forwarded from the session.
|
||||
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
|
||||
brief IdP outages.
|
||||
|
||||
No configuration needed — this is implicit behavior.
|
||||
|
||||
### HTTP 431 from backends
|
||||
|
||||
Either the ID token or the chunked OIDC cookies overflow your backend's header
|
||||
|
||||
+1
-1
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetUserIdentifier("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
|
||||
+30
-7
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
@@ -42,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -249,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
@@ -289,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens to prevent replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -360,9 +361,31 @@ func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
|
||||
return !strings.Contains(accept, "text/html")
|
||||
}
|
||||
|
||||
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
|
||||
// isRefreshTokenExpired checks whether the stored refresh token is likely
|
||||
// past its useful lifetime, using the cookie-side issued_at timestamp set by
|
||||
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
|
||||
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
|
||||
// MaxRefreshTokenAgeSeconds; 0 disables the check).
|
||||
//
|
||||
// The point of this check is to short-circuit the refresh path BEFORE the
|
||||
// thundering herd hits the IdP for a token the provider has almost certainly
|
||||
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
|
||||
// style polling clients from looping on invalid_grant after a long pause.
|
||||
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
|
||||
// This is a heuristic check - actual implementation would depend on
|
||||
// the specific provider and token metadata
|
||||
return false // Placeholder implementation
|
||||
if t == nil || session == nil {
|
||||
return false
|
||||
}
|
||||
if t.maxRefreshTokenAge <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
issuedAt := session.GetRefreshTokenIssuedAt()
|
||||
if issuedAt.IsZero() {
|
||||
// No timestamp recorded (legacy session pre-dating the issued_at
|
||||
// field). Don't force a re-auth - attempt refresh once and let the
|
||||
// IdP be the source of truth.
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Since(issuedAt) > t.maxRefreshTokenAge
|
||||
}
|
||||
|
||||
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Pre-populate session with old data
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-access-token-with-many-characters")
|
||||
session.SetRefreshToken("old-refresh-token-with-many-characters")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
|
||||
|
||||
// Verify old data is cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
|
||||
// Verify new data is set
|
||||
s.Equal(csrfToken, session.GetCSRF())
|
||||
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
session, err := sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
_ = session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
|
||||
session.mainSession.Values["redirect_count"] = 3
|
||||
|
||||
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
|
||||
|
||||
// Session should be cleared
|
||||
s.False(session.GetAuthenticated())
|
||||
s.Empty(session.GetEmail())
|
||||
s.Empty(session.GetUserIdentifier())
|
||||
s.Empty(session.GetIDToken())
|
||||
|
||||
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
|
||||
|
||||
@@ -113,6 +113,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
|
||||
}
|
||||
|
||||
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
|
||||
// by the refresh path to coalesce grants across Traefik replicas via Redis.
|
||||
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce("test-nonce")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("old-access-token")
|
||||
session.SetRefreshToken("old-refresh-token")
|
||||
session.SetIDToken("old-id-token")
|
||||
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Now perform selective clearing (as done in the fix)
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// Set initial session data
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("old@example.com")
|
||||
session.SetUserIdentifier("old@example.com")
|
||||
session.SetAccessToken("old-token")
|
||||
session.SetCSRF("existing-csrf")
|
||||
|
||||
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
|
||||
// NEW BEHAVIOR: Selective clearing
|
||||
session2.SetAuthenticated(false)
|
||||
session2.SetEmail("")
|
||||
session2.SetUserIdentifier("")
|
||||
session2.SetAccessToken("")
|
||||
session2.SetRefreshToken("")
|
||||
session2.SetIDToken("")
|
||||
|
||||
@@ -70,6 +70,33 @@ overwrite it).
|
||||
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
|
||||
dev). Otherwise leave it at default.
|
||||
|
||||
### Streaming Endpoints (SSE and WebSocket)
|
||||
|
||||
The middleware automatically bypasses the OIDC redirect for two request kinds
|
||||
that browsers cannot follow a 302 on:
|
||||
|
||||
| Bypass | Triggered by |
|
||||
|--------|--------------|
|
||||
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
|
||||
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
|
||||
|
||||
These requests do **not** require any explicit configuration — they are
|
||||
handled implicitly. However, the bypass is **not** unauthenticated:
|
||||
|
||||
- A valid, encrypted session cookie is required. Requests without one are
|
||||
rejected (the connection cannot proceed to the backend).
|
||||
- The session cookie is sealed with `sessionEncryptionKey`, so the
|
||||
`authenticated` flag cannot be forged.
|
||||
- Validation is cookie-only — no JWK fetch / signature verification — so
|
||||
streaming endpoints keep working when the OIDC provider is briefly
|
||||
unavailable.
|
||||
- The user identifier from the session is forwarded as `X-Forwarded-User`
|
||||
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
|
||||
|
||||
For browser clients, the user must complete the normal OIDC flow on a
|
||||
regular HTTP page first; the resulting session cookie is then reused on the
|
||||
SSE / WebSocket connection.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
@@ -113,6 +140,7 @@ strictAudienceValidation: true
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
|
||||
+11
-1
@@ -718,6 +718,11 @@ spec:
|
||||
<td class="py-2 px-3">86400</td>
|
||||
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
|
||||
<td class="py-2 px-3">21600</td>
|
||||
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
|
||||
</tr>
|
||||
<tr class="border-b border-gray-100 dark:border-gray-800">
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
|
||||
<td class="py-2 px-3">_oidc_raczylo_</td>
|
||||
@@ -858,7 +863,12 @@ spec:
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections</td>
|
||||
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
|
||||
<td class="py-2 px-3">false</td>
|
||||
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -2,6 +2,8 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if jwks == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
|
||||
Vendored
+3
@@ -24,6 +24,7 @@ type Config struct {
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
TLSServerName string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
@@ -34,6 +35,8 @@ type Config struct {
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
EnableMetrics bool
|
||||
EnableTLS bool
|
||||
TLSSkipVerify bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+73
-26
@@ -20,6 +20,7 @@ type HybridBackend struct {
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
l1BackfillBuffer chan *l1BackfillItem
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
@@ -28,6 +29,7 @@ type HybridBackend struct {
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
l1BackfillDrops atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// l1BackfillItem represents a deferred write of an L2-resolved value back into
|
||||
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
|
||||
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
|
||||
type l1BackfillItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
|
||||
// L2 hits regardless of request rate.
|
||||
h.wg.Add(1)
|
||||
go h.l1BackfillWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
// Populate L1 cache with value from L2 (write-through on read).
|
||||
// Hand off to the bounded backfill worker instead of spawning a goroutine
|
||||
// per read - under burst that would mint thousands of goroutines.
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
close(h.l1BackfillBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
h.queueL1Backfill(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
|
||||
return nil
|
||||
}
|
||||
|
||||
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
|
||||
// Drops on full buffer to keep the read path constant-time; the next L2 hit
|
||||
// for the same key simply re-queues it.
|
||||
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
|
||||
select {
|
||||
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
|
||||
default:
|
||||
h.l1BackfillDrops.Add(1)
|
||||
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// l1BackfillWorker drains the backfill queue serially. Single worker is
|
||||
// intentional - L1 writes are local and cheap, and serializing them keeps
|
||||
// goroutine count bounded under any read rate.
|
||||
func (h *HybridBackend) l1BackfillWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items best-effort then exit.
|
||||
for len(h.l1BackfillBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.l1BackfillBuffer:
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.l1BackfillBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
|
||||
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
|
||||
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
|
||||
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 256,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
const burst = 1000
|
||||
|
||||
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
|
||||
// fresh L1 backfill enqueue.
|
||||
for i := 0; i < burst; i++ {
|
||||
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
|
||||
}
|
||||
|
||||
baseline := runtime.NumGoroutine()
|
||||
|
||||
// Issue the burst as fast as possible; the backfill worker MUST be the
|
||||
// only goroutine doing L1 writes. Allow brief slack for the test runtime
|
||||
// scheduling but anything north of +20 means goroutine leakage.
|
||||
peak := baseline
|
||||
for i := 0; i < burst; i++ {
|
||||
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists)
|
||||
if g := runtime.NumGoroutine(); g > peak {
|
||||
peak = g
|
||||
}
|
||||
}
|
||||
|
||||
delta := peak - baseline
|
||||
if delta > 20 {
|
||||
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
|
||||
delta, baseline, peak)
|
||||
}
|
||||
|
||||
// L1 must eventually catch up via the worker. Worker drains serially so
|
||||
// give it a generous window proportional to the burst size.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
var populated int
|
||||
for i := 0; i < burst; i++ {
|
||||
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
|
||||
populated++
|
||||
}
|
||||
}
|
||||
// Be lenient: drops are acceptable under buffer pressure, just want
|
||||
// most of the keys to make it.
|
||||
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
|
||||
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
|
||||
}
|
||||
|
||||
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
|
||||
// buffer is saturated. Drops must be counted, never block, never spawn a
|
||||
// goroutine.
|
||||
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
// Tiny buffer + slow primary writes via failSet so the worker stays
|
||||
// blocked enough to overflow the buffer.
|
||||
hybrid, err := NewHybridBackend(&HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 4,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer hybrid.Close()
|
||||
|
||||
// Stop the worker from draining: cancel the underlying context so the
|
||||
// worker bails out, leaving us with a cold buffer and the queue method
|
||||
// itself responsible for drop accounting.
|
||||
hybrid.cancel()
|
||||
// Wait for worker to exit so it can't drain.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
|
||||
}
|
||||
|
||||
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
|
||||
"expected some drops when buffer is saturated and worker is stopped")
|
||||
}
|
||||
Vendored
+3
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
TLSServerName: config.TLSServerName,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
EnableTLS: config.EnableTLS,
|
||||
TLSSkipVerify: config.TLSSkipVerify,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
|
||||
+25
-3
@@ -2,6 +2,7 @@ package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -31,6 +32,7 @@ type ConnectionPool struct {
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
TLSServerName string // SNI server name; defaults to host(Address) when empty
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
@@ -39,6 +41,8 @@ type PoolConfig struct {
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
|
||||
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
conn, err = p.createConnection(ctx)
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if p.config.EnableTLS {
|
||||
serverName := p.config.TLSServerName
|
||||
if serverName == "" {
|
||||
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
|
||||
serverName = host
|
||||
}
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
|
||||
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
} else {
|
||||
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
+230
@@ -0,0 +1,230 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// drainRESPRequest consumes a single RESP request (array or inline) from r and
|
||||
// returns true on success. Any read error returns false.
|
||||
func drainRESPRequest(r *bufio.Reader) bool {
|
||||
header, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(header, "*") {
|
||||
return true // inline command (single line) — already consumed
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
|
||||
if err != nil || n <= 0 {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
// Each bulk: "$len\r\n<bytes>\r\n"
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := r.ReadString('\n'); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
|
||||
// answer PING with +PONG. Returns the listener address and a self-signed cert.
|
||||
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{der},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
c, acceptErr := listener.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(conn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer conn.Close()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
if !drainRESPRequest(reader) {
|
||||
return
|
||||
}
|
||||
_, _ = conn.Write([]byte("+PONG\r\n"))
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
stop = func() {
|
||||
close(stopCh)
|
||||
_ = listener.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
return listener.Addr().String(), der, stop
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
|
||||
// Regression test for issue #133 (enableTLS not propagated to client).
|
||||
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
|
||||
// TLSSkipVerify=false rejects a self-signed server cert.
|
||||
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
|
||||
addr, _, stop := startTLSPingServer(t)
|
||||
defer stop()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "tls")
|
||||
}
|
||||
|
||||
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
|
||||
// fails to handshake against a plain (non-TLS) listener.
|
||||
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
|
||||
plain, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer plain.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, acceptErr := plain.Accept()
|
||||
if acceptErr != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: plain.Addr().String(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
|
||||
// when EnableTLS=false (default).
|
||||
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
pool, err := NewConnectionPool(&PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
EnableTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
|
||||
// the fix for issue #132: token refresh path hardcoded the "email" claim and
|
||||
// ignored the configured userIdentifierClaim. Keycloak users without an email
|
||||
// claim (using sub or another identifier) were being kicked out on refresh
|
||||
// even though their initial login worked.
|
||||
//
|
||||
// The callback path (auth_flow.go) already honored userIdentifierClaim with
|
||||
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
|
||||
// after PR #100 (commit a316a98).
|
||||
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims map[string]any
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
expectedIdentifier string
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "sub claim configured, only sub present (Keycloak no-email case)",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-keycloak-12345",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user-uuid-keycloak-12345",
|
||||
},
|
||||
{
|
||||
name: "preferred_username configured, claim present",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"preferred_username": "alice",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "alice",
|
||||
},
|
||||
{
|
||||
name: "configured claim missing, falls back to sub",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]any{
|
||||
"sub": "fallback-sub-id",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
},
|
||||
{
|
||||
name: "email default, email present (backward compatibility)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"sub": "user-uuid-12345",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: true,
|
||||
expectedIdentifier: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "email default, no email and no sub - refresh fails",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]any{
|
||||
"exp": float64(9999999999),
|
||||
},
|
||||
expectSuccess: false,
|
||||
expectedIdentifier: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager(
|
||||
"test-encryption-key-32-bytes-long!!",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
0,
|
||||
NewLogger("error"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("session manager: %v", err)
|
||||
}
|
||||
defer sessionManager.Shutdown()
|
||||
|
||||
capturedClaims := tt.claims
|
||||
tOidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
userIdentifierClaim: tt.userIdentifierClaim,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: &EnhancedMockTokenExchanger{
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
IDToken: "new-id-token-jwt",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
},
|
||||
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
|
||||
extractClaimsFunc: func(token string) (map[string]any, error) {
|
||||
return capturedClaims, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("get session: %v", err)
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
session.SetRefreshToken("initial-refresh-token")
|
||||
|
||||
refreshed := tOidc.refreshToken(rw, req, session)
|
||||
|
||||
if refreshed != tt.expectSuccess {
|
||||
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
|
||||
}
|
||||
|
||||
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
|
||||
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
@@ -18,6 +19,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// parsedKeysSuffix marks the parallel UniversalCache entry that stores
|
||||
// pre-parsed public keys for a given JWKS URL.
|
||||
const parsedKeysSuffix = ":parsed"
|
||||
|
||||
// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the
|
||||
// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip
|
||||
// on every JWT verification — a costly operation under the yaegi interpreter
|
||||
// that hosts Traefik plugins.
|
||||
type parsedJWKS struct {
|
||||
keys map[string]crypto.PublicKey
|
||||
}
|
||||
|
||||
// JWK represents a JSON Web Key as defined in RFC 7517.
|
||||
// It can represent different key types including RSA, EC, and symmetric keys.
|
||||
type JWK struct {
|
||||
@@ -49,6 +62,7 @@ type JWKCache struct {
|
||||
// JWKCacheInterface defines the contract for JWK caching implementations.
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error)
|
||||
Cleanup()
|
||||
Close()
|
||||
}
|
||||
@@ -96,6 +110,62 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the parsed public key for a given kid, fetching and
|
||||
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
|
||||
// stored alongside the raw JWKSet under a sibling cache key with the same
|
||||
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
|
||||
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
parsedKey := jwksURL + parsedKeysSuffix
|
||||
if v, found := c.cache.Get(parsedKey); found {
|
||||
if pj, ok := v.(*parsedJWKS); ok {
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jwks, err := c.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pj := buildParsedJWKS(jwks)
|
||||
_ = c.cache.Set(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
if k, ok := pj.keys[kid]; ok {
|
||||
return k, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
// buildParsedJWKS pre-parses every JWK in the set into the matching
|
||||
// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so
|
||||
// a single bad key does not block the rest of the keyset.
|
||||
func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
|
||||
out := make(map[string]crypto.PublicKey, len(jwks.Keys))
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid == "" {
|
||||
continue
|
||||
}
|
||||
var pub crypto.PublicKey
|
||||
var err error
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
pub, err = k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
pub, err = k.ToECDSAPublicKey()
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out[k.Kid] = pub
|
||||
}
|
||||
return &parsedJWKS{keys: out}
|
||||
}
|
||||
|
||||
// Cleanup is a no-op as cleanup is handled by UniversalCache
|
||||
func (c *JWKCache) Cleanup() {
|
||||
// Handled internally by UniversalCache
|
||||
|
||||
@@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error {
|
||||
// - An error if the key parsing fails, the algorithm is unsupported,
|
||||
// or the signature verification fails
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
return verifySignatureWithKey(tokenString, pubKey, alg)
|
||||
}
|
||||
|
||||
// verifySignatureWithKey verifies a JWT signature using an already-parsed
|
||||
// public key, skipping the PEM-encode/decode round trip that verifySignature
|
||||
// performs. This is the hot path used by VerifyJWTSignatureAndClaims.
|
||||
func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
@@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
var hashFunc crypto.Hash
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
@@ -639,6 +640,26 @@ func (m *mockJWKCacheForLogout) GetJWKS(ctx context.Context, jwksURL string, htt
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockJWKCacheForLogout) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range jwks.Keys {
|
||||
k := &jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *mockJWKCacheForLogout) Clear() {}
|
||||
func (m *mockJWKCacheForLogout) Cleanup() {}
|
||||
func (m *mockJWKCacheForLogout) Close() {}
|
||||
@@ -755,6 +776,22 @@ func (s *staticJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient
|
||||
return s.jwks, nil
|
||||
}
|
||||
|
||||
func (s *staticJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
for i := range s.jwks.Keys {
|
||||
k := &s.jwks.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (s *staticJWKCache) Clear() {}
|
||||
func (s *staticJWKCache) Cleanup() {}
|
||||
func (s *staticJWKCache) Close() {}
|
||||
|
||||
@@ -226,6 +226,13 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return 60 * time.Second
|
||||
}(),
|
||||
maxRefreshTokenAge: func() time.Duration {
|
||||
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
|
||||
if config.MaxRefreshTokenAgeSeconds > 0 {
|
||||
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
|
||||
}
|
||||
return 0
|
||||
}(),
|
||||
tokenCleanupStopChan: make(chan struct{}),
|
||||
metadataRefreshStopChan: make(chan struct{}),
|
||||
ctx: pluginCtx,
|
||||
@@ -242,6 +249,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
|
||||
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
|
||||
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
|
||||
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
@@ -260,6 +268,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
tokenResilienceConfig := DefaultTokenResilienceConfig()
|
||||
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
|
||||
|
||||
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
|
||||
// call, preventing the thundering herd that yields invalid_grant when the IdP
|
||||
// rotates refresh tokens (Zitadel/Authentik default).
|
||||
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
|
||||
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
|
||||
+199
-47
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_EventStream tests the event-stream bypass functionality
|
||||
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
|
||||
// handshake must skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but it must STILL require an authenticated session, otherwise
|
||||
// any caller could reach the backend by setting Accept: text/event-stream.
|
||||
func TestServeHTTP_EventStream(t *testing.T) {
|
||||
nextCalled := false
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
|
||||
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
|
||||
}
|
||||
})
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
// Build an authenticated session and inject its cookies onto req.
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated SSE request to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
|
||||
// handshake bypasses the OIDC redirect (clients can't follow it) but the
|
||||
// session must already be authenticated, otherwise the backend is exposed
|
||||
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
|
||||
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
newOidc := func(next http.Handler) *TraefikOidc {
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
return oidc
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/events", nil)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
rw := httptest.NewRecorder()
|
||||
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected event-stream request to bypass OIDC")
|
||||
}
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
|
||||
}
|
||||
if nextCalled {
|
||||
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
var forwardedUser string
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
forwardedUser = r.Header.Get("X-Forwarded-User")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
// Mixed-case + multi-token Connection header to exercise parsing.
|
||||
req.Header.Set("Connection", "keep-alive, Upgrade")
|
||||
req.Header.Set("Upgrade", "WebSocket")
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
session.SetUserIdentifier("ws-user@example.com")
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("failed to mark session authenticated: %v", err)
|
||||
}
|
||||
setupRW := httptest.NewRecorder()
|
||||
if err := session.Save(req, setupRW); err != nil {
|
||||
t.Fatalf("failed to save session: %v", err)
|
||||
}
|
||||
for _, c := range setupRW.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
|
||||
}
|
||||
if forwardedUser != "ws-user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
|
||||
// Sanity: requests without Upgrade headers must NOT hit the WS
|
||||
// bypass branch (otherwise the new code path could short-circuit
|
||||
// normal authentication).
|
||||
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("backend must not be called for unauthenticated plain HTTP")
|
||||
}))
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
rw := httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
if rw.Code == http.StatusOK {
|
||||
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
|
||||
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "successful authorization with email",
|
||||
setupSession: func() *MockSessionData {
|
||||
session := &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "no email triggers reauth",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "",
|
||||
userIdentifier: "",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "roles and groups authorization",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "unauthorized role/group returns 403",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "template headers processing",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
isDirty: false,
|
||||
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
name: "OPTIONS request with CORS",
|
||||
setupSession: func() *MockSessionData {
|
||||
return &MockSessionData{
|
||||
email: "user@example.com",
|
||||
userIdentifier: "user@example.com",
|
||||
idToken: "test-id-token",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
manager: &SessionManager{logger: NewLogger("debug")},
|
||||
}
|
||||
// Copy values from mock to concrete session
|
||||
concreteSession.SetEmail(session.email)
|
||||
concreteSession.SetUserIdentifier(session.userIdentifier)
|
||||
concreteSession.SetIDToken(session.idToken)
|
||||
concreteSession.SetAccessToken(session.accessToken)
|
||||
concreteSession.SetRefreshToken(session.refreshToken)
|
||||
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
|
||||
// MockSessionData is a test implementation of SessionData interface
|
||||
type MockSessionData struct {
|
||||
email string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
userIdentifier string
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
redirectCount int
|
||||
authenticated bool
|
||||
isDirty bool
|
||||
}
|
||||
|
||||
func (m *MockSessionData) GetEmail() string { return m.email }
|
||||
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
|
||||
func (m *MockSessionData) GetIDToken() string { return m.idToken }
|
||||
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
|
||||
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
|
||||
func (m *MockSessionData) SetEmail(email string) { m.email = email }
|
||||
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
|
||||
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
|
||||
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
|
||||
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
|
||||
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -771,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Now add OIDC session cookies (simulating what the browser would send)
|
||||
@@ -852,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
@@ -899,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only OIDC cookies
|
||||
@@ -950,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add only non-OIDC cookies
|
||||
@@ -1013,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Add cookies with the custom prefix (should be stripped)
|
||||
|
||||
+41
-15
@@ -208,6 +208,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if m.Err != nil {
|
||||
return nil, m.Err
|
||||
}
|
||||
if m.JWKS == nil {
|
||||
return nil, fmt.Errorf("JWKS is nil")
|
||||
}
|
||||
for i := range m.JWKS.Keys {
|
||||
k := &m.JWKS.Keys[i]
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
}
|
||||
switch k.Kty {
|
||||
case "RSA":
|
||||
return k.ToRSAPublicKey()
|
||||
case "EC":
|
||||
return k.ToECDSAPublicKey()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
func (m *MockJWKCache) Cleanup() {
|
||||
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
|
||||
// Real cleanup is for expired entries, not resetting all data
|
||||
@@ -554,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case to avoid replay issues
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -577,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
// even if session.SetAuthenticated(true) was called.
|
||||
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
||||
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -634,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -652,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -680,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true) // Set flag initially
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Create an expired token for this test
|
||||
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
@@ -715,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(nearExpiryToken)
|
||||
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
||||
},
|
||||
@@ -746,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken(validToken)
|
||||
session.SetIDToken(validToken) // Ensure ID token is also set
|
||||
session.SetRefreshToken("should-not-be-used-refresh-token")
|
||||
@@ -766,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -788,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
requestPath: "/protected",
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@disallowed.com") // Use disallowed domain
|
||||
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
||||
// Generate a fresh valid token for this test case
|
||||
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
@@ -2153,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
||||
})
|
||||
session.SetAccessToken(expiredToken)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
@@ -2730,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2756,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2783,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
@@ -2803,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
@@ -2825,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
|
||||
+150
-78
@@ -14,21 +14,40 @@ import (
|
||||
)
|
||||
|
||||
// bypassReason describes why a request is being forwarded without OIDC auth.
|
||||
// It is only used for logging and to decide whether extra SSE-specific work
|
||||
// It is only used for logging and to decide whether extra side-effects
|
||||
// (propagating the user header from an existing session) should run.
|
||||
const (
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonExcluded = "excluded-url"
|
||||
bypassReasonSSE = "sse"
|
||||
bypassReasonWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
|
||||
// (RFC 6455). The middleware can only see the handshake; once Traefik
|
||||
// completes the upgrade it forwards frames directly, so we never re-process
|
||||
// per-frame traffic. We bypass auth on the handshake the same way we do for
|
||||
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// shouldBypassAuth decides whether a request must skip OIDC authentication
|
||||
// entirely. It returns (true, reason) when either the request path matches a
|
||||
// configured excluded URL or the Accept header asks for a text/event-stream
|
||||
// response (SSE). The reason lets ServeHTTP apply any side-effects that are
|
||||
// unique to the bypass kind (e.g. propagating user headers for SSE).
|
||||
// configured excluded URL, the Accept header asks for a text/event-stream
|
||||
// response (SSE), or the request is a WebSocket upgrade handshake. The
|
||||
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
|
||||
// kind (e.g. propagating user headers).
|
||||
//
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded and SSE
|
||||
// traffic is never blocked by a slow/broken provider.
|
||||
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
|
||||
// WebSocket traffic is never blocked by a slow/broken provider.
|
||||
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
return true, bypassReasonExcluded
|
||||
@@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
|
||||
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
|
||||
return true, bypassReasonSSE
|
||||
}
|
||||
if isWebSocketUpgrade(req) {
|
||||
return true, bypassReasonWebSocket
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// applySSEUserHeaders attempts to copy the authenticated user's identity from
|
||||
// an existing session onto the outgoing SSE request so downstream services
|
||||
// can still see who the user is. Failures are logged (not silenced) because
|
||||
// they indicate either a corrupt cookie or a misconfigured session manager
|
||||
// and are useful for debugging, but they never block the bypass itself.
|
||||
func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) {
|
||||
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
|
||||
// requests and, on success, copies the authenticated user's identity onto
|
||||
// the outgoing request so downstream services can see who the user is.
|
||||
//
|
||||
// Returns true when the request carries a valid authenticated session and
|
||||
// the bypass should proceed. Returns false when no usable session is
|
||||
// present; callers must then reject the request (typically with 401) to
|
||||
// prevent unauthenticated traffic from reaching the backend just by setting
|
||||
// `Accept: text/event-stream` or sending a WebSocket upgrade.
|
||||
//
|
||||
// The check is cookie-only: the session cookie is sealed by our encryption
|
||||
// key, so the authenticated flag cannot be forged. We do NOT run full token
|
||||
// signature verification here so that SSE/WS keeps working when the OIDC
|
||||
// provider is briefly unavailable for JWK fetches.
|
||||
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
|
||||
if t.sessionManager == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
// Intentionally not fatal: SSE requests bypass auth, we just lose the
|
||||
// forwarded-user header for this request.
|
||||
t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err)
|
||||
return
|
||||
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
|
||||
return false
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
return
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
|
||||
return false
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
|
||||
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
}
|
||||
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
|
||||
return true
|
||||
}
|
||||
|
||||
// ServeHTTP implements the main middleware logic for processing HTTP requests.
|
||||
@@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded URLs
|
||||
// and SSE requests must not block on provider init. For SSE we additionally
|
||||
// attempt to forward the user identity from an existing session (best
|
||||
// effort) so downstream handlers still see X-Forwarded-User.
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
// URLs, SSE and WebSocket upgrade requests must not block on provider
|
||||
// init. For SSE/WebSocket we ALSO require an authenticated session
|
||||
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
|
||||
// of in-flight streams can't follow an OIDC redirect, so forwarding
|
||||
// unauthenticated traffic would silently expose the backend.
|
||||
if bypass, reason := t.shouldBypassAuth(req); bypass {
|
||||
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
|
||||
if reason == bypassReasonSSE {
|
||||
t.applySSEUserHeaders(req)
|
||||
switch reason {
|
||||
case bypassReasonExcluded:
|
||||
// Operator-declared excluded URLs forward unconditionally.
|
||||
t.next.ServeHTTP(rw, req)
|
||||
case bypassReasonSSE, bypassReasonWebSocket:
|
||||
// Skip the OIDC redirect dance (clients can't follow it
|
||||
// mid-stream) but still require an authenticated session.
|
||||
// Otherwise an unauthenticated client could hit the backend
|
||||
// just by setting Accept: text/event-stream or sending a
|
||||
// WebSocket upgrade.
|
||||
if !t.applyBypassUserHeaders(req, reason) {
|
||||
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
default:
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -309,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
userIdentifier = session.GetUserIdentifier()
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
@@ -359,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Info("No email found in session during final processing, initiating re-auth")
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
// Reset redirect count to prevent loops when session is invalid
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
@@ -374,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
// Clear the session and redirect to login
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
@@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
tokenForClaims := session.GetIDToken()
|
||||
if tokenForClaims == "" {
|
||||
tokenForClaims = session.GetAccessToken()
|
||||
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
// Reset redirect count to prevent loops when token is missing
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token, so concurrent dashboard
|
||||
// panel requests on the same session don't repeatedly base64-decode and
|
||||
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
|
||||
// that hosts Traefik plugins). idClaims is reused below by the
|
||||
// header-templates branch.
|
||||
idToken := session.GetIDToken()
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
// Choose which claims drive groups/roles extraction. Prefer the ID
|
||||
// token (cached) and fall back to the access token if there is no ID
|
||||
// token in the session — matching the prior behavior for opaque
|
||||
// ID-token providers.
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if accessToken := session.GetAccessToken(); accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
var groups, roles []string
|
||||
|
||||
if groupClaimsErr == nil && groupClaims != nil {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize empty slices
|
||||
var groups, roles []string
|
||||
|
||||
if tokenForClaims != "" {
|
||||
var err error
|
||||
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
|
||||
if err != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
// Reset redirect count to prevent loops when claim extraction fails
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
} else if err == nil {
|
||||
if err == nil {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
@@ -429,54 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Forwarded-User", userIdentifier)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-User", userIdentifier)
|
||||
if idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
// Reuse claims parsed earlier in this request if the ID token has not
|
||||
// changed. Saves an unnecessary JWT parse on every authenticated
|
||||
// request that uses headerTemplates.
|
||||
claims, err := session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
|
||||
if idClaimsErr != nil {
|
||||
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
|
||||
} else {
|
||||
// idClaims may be nil when no ID token is present; templates
|
||||
// referencing .Claims.* will simply produce empty values, which
|
||||
// matches the prior behavior.
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": session.GetAccessToken(),
|
||||
"IDToken": session.GetIDToken(),
|
||||
"IDToken": idToken,
|
||||
"RefreshToken": session.GetRefreshToken(),
|
||||
"Claims": claims,
|
||||
"Claims": idClaims,
|
||||
}
|
||||
|
||||
for headerName, tmpl := range t.headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tmpl.Execute(&buf, templateData); err != nil {
|
||||
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
|
||||
continue
|
||||
}
|
||||
headerValue := buf.String()
|
||||
|
||||
req.Header.Set(headerName, headerValue)
|
||||
|
||||
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
|
||||
}
|
||||
session.MarkDirty()
|
||||
t.logger.Debugf("Session marked dirty after templated header processing.")
|
||||
// NOTE: templates only mutate request headers (not session state),
|
||||
// so we deliberately do NOT MarkDirty / Save here. Previously every
|
||||
// authenticated request with header templates re-encrypted and
|
||||
// rewrote all session cookies, which was a measurable CPU and
|
||||
// Set-Cookie tax on dashboards that poll many panels per second.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
|
||||
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetUserIdentifier("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetUserIdentifier("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetUserIdentifier(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
@@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
|
||||
return refreshCoordinatorSessionID(token)
|
||||
}
|
||||
|
||||
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
|
||||
// for both deduplication and per-session attempt tracking. Using sha256 of the
|
||||
// raw token means each rotation produces a fresh sessionID with its own attempt
|
||||
// budget, which is what we want.
|
||||
func refreshCoordinatorSessionID(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
|
||||
// coordinated refresh result. It is wider than RefreshTimeout so a follower
|
||||
// always sees the leader's result instead of timing out independently.
|
||||
const refreshCoordinatorWaitTimeout = 35 * time.Second
|
||||
|
||||
// isUnderMemoryPressure checks if the system is under memory pressure by
|
||||
// consulting the global memory monitor. Returns true when pressure reaches
|
||||
// High or Critical, at which point we refuse new refresh operations to
|
||||
|
||||
@@ -0,0 +1,164 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// stubTokenExchanger lets us count how many upstream refresh-token grants
|
||||
// happen for a given refresh_token across concurrent middleware-level calls.
|
||||
type stubTokenExchanger struct {
|
||||
calls int32
|
||||
delay time.Duration
|
||||
resp *TokenResponse
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.delay > 0 {
|
||||
time.Sleep(s.delay)
|
||||
}
|
||||
return s.resp, nil
|
||||
}
|
||||
|
||||
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
|
||||
// concurrent calls to coordinatedTokenRefresh with the same refresh token
|
||||
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
|
||||
//
|
||||
// Without the wireup this assertion fails (one upstream call per goroutine).
|
||||
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 100 * time.Millisecond,
|
||||
resp: &TokenResponse{
|
||||
AccessToken: "new_access",
|
||||
RefreshToken: "new_refresh",
|
||||
IDToken: "new_id",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const concurrency = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
start := make(chan struct{})
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "new_access" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
got := atomic.LoadInt32(&stub.calls)
|
||||
// Up to 2 is acceptable to absorb the documented timing slack in the
|
||||
// existing coordinator tests (e.g. operation just cleaned up before a
|
||||
// late goroutine reads the in-flight map). Anything beyond that means
|
||||
// coalescing is broken.
|
||||
if got > 2 {
|
||||
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
|
||||
// coordinator path so existing tests that build TraefikOidc literals stay
|
||||
// green.
|
||||
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenExchanger: stub,
|
||||
// refreshCoordinator deliberately nil
|
||||
}
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream call, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
|
||||
// distinct refresh tokens are not falsely coalesced.
|
||||
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
delay: 20 * time.Millisecond,
|
||||
resp: &TokenResponse{AccessToken: "ok"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cfg := DefaultRefreshCoordinatorConfig()
|
||||
cfg.MaxRefreshAttempts = 10000
|
||||
cfg.MaxConcurrentRefreshes = 32
|
||||
cfg.DeduplicationCleanupDelay = 0
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
const distinct = 8
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(distinct)
|
||||
for i := 0; i < distinct; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
|
||||
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
|
||||
// replica dedup contract: Set/Get with TTL. Used in place of the universal
|
||||
// cache singleton so these tests stay hermetic.
|
||||
type inMemoryCache struct {
|
||||
entries map[string]inMemoryCacheEntry
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type inMemoryCacheEntry struct {
|
||||
expiresAt time.Time
|
||||
value interface{}
|
||||
}
|
||||
|
||||
func newInMemoryCache() *inMemoryCache {
|
||||
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Get(key string) (any, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
e, ok := c.entries[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(e.expiresAt) {
|
||||
delete(c.entries, key)
|
||||
return nil, false
|
||||
}
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) Delete(key string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.entries, key)
|
||||
}
|
||||
|
||||
func (c *inMemoryCache) SetMaxSize(int) {}
|
||||
func (c *inMemoryCache) Cleanup() {}
|
||||
func (c *inMemoryCache) Close() {}
|
||||
func (c *inMemoryCache) Size() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.entries)
|
||||
}
|
||||
func (c *inMemoryCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.entries = map[string]inMemoryCacheEntry{}
|
||||
}
|
||||
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
|
||||
|
||||
// erroringTokenExchanger always errors - simulates an IdP rejection.
|
||||
type erroringTokenExchanger struct {
|
||||
calls int32
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
|
||||
atomic.AddInt32(&e.calls, 1)
|
||||
return nil, errors.New("invalid_grant")
|
||||
}
|
||||
|
||||
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
|
||||
|
||||
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
|
||||
// replica having just refreshed: the shared cache already has the result, so
|
||||
// this pod must reuse it without ever calling the IdP.
|
||||
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "should_not_be_called"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
preExisting := &TokenResponse{
|
||||
AccessToken: "from_peer",
|
||||
RefreshToken: "rotated_by_peer",
|
||||
IDToken: "id_from_peer",
|
||||
}
|
||||
rt := "shared_refresh_token"
|
||||
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "from_peer" {
|
||||
t.Fatalf("expected peer-provided response, got %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 0 {
|
||||
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
|
||||
// cache miss the leader stores its result for peers to find within the TTL.
|
||||
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
|
||||
stub := &stubTokenExchanger{
|
||||
resp: &TokenResponse{AccessToken: "fresh_grant"},
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: stub,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
rt := "fresh_refresh_token"
|
||||
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if got := atomic.LoadInt32(&stub.calls); got != 1 {
|
||||
t.Fatalf("expected 1 upstream call, got %d", got)
|
||||
}
|
||||
|
||||
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
|
||||
if !ok {
|
||||
t.Fatal("expected refresh result to be cached after upstream success")
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
|
||||
t.Fatalf("cached value malformed: %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
|
||||
// dedup cache when the IdP rejects the grant. Peers must run their own
|
||||
// refresh; they cannot inherit an error.
|
||||
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
|
||||
failing := &erroringTokenExchanger{}
|
||||
logger := NewLogger("error")
|
||||
cache := newInMemoryCache()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
logger: logger,
|
||||
tokenExchanger: failing,
|
||||
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
|
||||
refreshResultCache: cache,
|
||||
}
|
||||
defer oidc.refreshCoordinator.Shutdown()
|
||||
|
||||
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
|
||||
t.Fatal("expected an error from the failing exchanger")
|
||||
}
|
||||
if cache.Size() != 0 {
|
||||
t.Fatalf("error result must not be cached, size=%d", cache.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
|
||||
// reads from. We can't reuse sessionPool.Get() here because that requires a
|
||||
// fully initialized SessionManager - overkill for this unit-level check.
|
||||
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
|
||||
t.Helper()
|
||||
rs := sessions.NewSession(nil, "refresh")
|
||||
if !issuedAt.IsZero() {
|
||||
rs.Values["issued_at"] = issuedAt.Unix()
|
||||
}
|
||||
return &SessionData{
|
||||
refreshSession: rs,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
idTokenChunks: make(map[int]*sessions.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 0}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
|
||||
if tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=false within max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
|
||||
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
|
||||
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
|
||||
if !tr.isRefreshTokenExpired(sd) {
|
||||
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
|
||||
var tr *TraefikOidc
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil receiver must not panic and must return false")
|
||||
}
|
||||
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
|
||||
if tr.isRefreshTokenExpired(nil) {
|
||||
t.Fatal("nil session must return false")
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
|
||||
// Simulate successful Azure authentication
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetUserIdentifier("user@example.com")
|
||||
// Azure may use opaque access tokens
|
||||
session.SetAccessToken("opaque-azure-access-token")
|
||||
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
|
||||
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
|
||||
assert.Equal(t, "user@example.com", session2.GetEmail())
|
||||
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
|
||||
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
|
||||
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
|
||||
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
|
||||
|
||||
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
|
||||
// Set up the attacker's session with malicious data
|
||||
attackerSession.SetAuthenticated(true)
|
||||
attackerSession.SetEmail("attacker@evil.com")
|
||||
attackerSession.SetUserIdentifier("attacker@evil.com")
|
||||
attackerSession.SetIDToken(ValidIDToken)
|
||||
attackerSession.SetAccessToken(ValidAccessToken)
|
||||
|
||||
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get the email from the session
|
||||
email := session.GetEmail()
|
||||
email := session.GetUserIdentifier()
|
||||
w.Header().Set("X-User-Email", email)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
+26
-26
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
|
||||
A string `json:"a,omitempty"`
|
||||
R string `json:"r,omitempty"`
|
||||
I string `json:"i,omitempty"`
|
||||
E string `json:"e,omitempty"`
|
||||
Ui string `json:"ui,omitempty"`
|
||||
Cs string `json:"cs,omitempty"`
|
||||
N string `json:"n,omitempty"`
|
||||
Cv string `json:"cv,omitempty"`
|
||||
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
|
||||
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
|
||||
// All other mainSession.Values keys are stored in the X (extra) field.
|
||||
var knownSessionKeys = map[string]bool{
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"email": true,
|
||||
"authenticated": true,
|
||||
"access_token": true,
|
||||
"refresh_token": true,
|
||||
"id_token": true,
|
||||
"user_identifier": true,
|
||||
"authenticated": true,
|
||||
"csrf": true,
|
||||
"nonce": true,
|
||||
"code_verifier": true,
|
||||
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
|
||||
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
|
||||
|
||||
// Populate legacy session values from combined payload
|
||||
sessionData.mainSession.Values["email"] = payload.E
|
||||
sessionData.mainSession.Values["user_identifier"] = payload.Ui
|
||||
sessionData.mainSession.Values["authenticated"] = payload.Au
|
||||
sessionData.mainSession.Values["csrf"] = payload.Cs
|
||||
sessionData.mainSession.Values["nonce"] = payload.N
|
||||
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
|
||||
A: sd.getAccessTokenUnsafe(),
|
||||
R: sd.getRefreshTokenUnsafe(),
|
||||
I: sd.getIDTokenUnsafe(),
|
||||
E: sd.getEmailUnsafe(),
|
||||
Ui: sd.getUserIdentifierUnsafe(),
|
||||
Au: sd.getAuthenticatedUnsafe(),
|
||||
Cs: sd.getCSRFUnsafe(),
|
||||
N: sd.getNonceUnsafe(),
|
||||
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail retrieves the authenticated user's email address.
|
||||
// The email is extracted from ID token claims and used for
|
||||
// authorization decisions and header injection.
|
||||
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
|
||||
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
|
||||
// upn, preferred_username, etc.). The value is used for authorization
|
||||
// decisions and header injection.
|
||||
// Returns:
|
||||
// - The user's email address string, or an empty string if not set.
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
// - The user identifier string, or an empty string if not set.
|
||||
func (sd *SessionData) GetUserIdentifier() string {
|
||||
sd.sessionMutex.RLock()
|
||||
defer sd.sessionMutex.RUnlock()
|
||||
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// SetEmail stores the authenticated user's email address.
|
||||
// The email is typically extracted from the 'email' claim in the ID token.
|
||||
// SetUserIdentifier stores the authenticated user's identifier value.
|
||||
// Parameters:
|
||||
// - email: The user's email address to store.
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
|
||||
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
|
||||
sd.sessionMutex.Lock()
|
||||
defer sd.sessionMutex.Unlock()
|
||||
|
||||
currentVal, _ := sd.mainSession.Values["email"].(string)
|
||||
if currentVal != email {
|
||||
sd.mainSession.Values["email"] = email
|
||||
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
if currentVal != userIdentifier {
|
||||
sd.mainSession.Values["user_identifier"] = userIdentifier
|
||||
sd.dirty = true
|
||||
}
|
||||
}
|
||||
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
|
||||
return result.Token
|
||||
}
|
||||
|
||||
// getEmailUnsafe retrieves the email without acquiring locks.
|
||||
func (sd *SessionData) getEmailUnsafe() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
|
||||
func (sd *SessionData) getUserIdentifierUnsafe() string {
|
||||
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
|
||||
return userIdentifier
|
||||
}
|
||||
|
||||
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
|
||||
|
||||
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
|
||||
s.False(session.IsDirty())
|
||||
}
|
||||
|
||||
// TestSessionData_SetEmail tests email setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
|
||||
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
|
||||
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
session, err := s.sessionManager.GetSession(req)
|
||||
s.Require().NoError(err)
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Set email
|
||||
session.SetEmail("test@example.com")
|
||||
s.Equal("test@example.com", session.GetEmail())
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.Equal("test@example.com", session.GetUserIdentifier())
|
||||
s.True(session.IsDirty())
|
||||
}
|
||||
|
||||
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
|
||||
// Set some data
|
||||
err = session.SetAuthenticated(true)
|
||||
s.Require().NoError(err)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.SetCSRF("csrf-token")
|
||||
|
||||
// Clear session
|
||||
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
// Modify session
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
s.True(session.IsDirty())
|
||||
|
||||
// Save session
|
||||
|
||||
+6
-6
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Set up initial session state (what user has when first logging in)
|
||||
session1.SetAuthenticated(true)
|
||||
session1.SetEmail(originalUserData["email"].(string))
|
||||
session1.SetUserIdentifier(originalUserData["email"].(string))
|
||||
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
|
||||
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
|
||||
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
|
||||
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
// Simulate what happens when middleware detects expired tokens
|
||||
// It should preserve session state while attempting token refresh
|
||||
originalAuth := session2.GetAuthenticated()
|
||||
originalEmail := session2.GetEmail()
|
||||
originalEmail := session2.GetUserIdentifier()
|
||||
|
||||
// Reconstruct user data from individual stored keys
|
||||
originalUserDataStored := make(map[string]interface{})
|
||||
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
|
||||
|
||||
// Verify all session data is still intact after token refresh
|
||||
postRefreshAuth := session2.GetAuthenticated()
|
||||
postRefreshEmail := session2.GetEmail()
|
||||
postRefreshEmail := session2.GetUserIdentifier()
|
||||
userDataPresent := true
|
||||
for k := range originalUserData {
|
||||
if session2.mainSession.Values["user_data_"+k] == nil {
|
||||
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with specific creation time
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetUserIdentifier("test@example.com")
|
||||
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
|
||||
|
||||
// Create tokens with specific expiry
|
||||
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
|
||||
// Set up session with data that should be preserved or removed
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("cleanup@example.com")
|
||||
session.SetUserIdentifier("cleanup@example.com")
|
||||
|
||||
session.mainSession.Values["user_data"] = "Test User|user-123"
|
||||
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
|
||||
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
|
||||
if scenario.shouldCleanup {
|
||||
if sessionTooOld {
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
for key := range session.mainSession.Values {
|
||||
|
||||
+15
@@ -55,6 +55,15 @@ type Config struct {
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
// than this, requests treat it as expired up-front - returning 401 to
|
||||
// AJAX callers and triggering full re-auth on navigations - instead of
|
||||
// hammering the IdP with grants that will only fail with invalid_grant.
|
||||
// IdPs do not expose RT TTL on the wire, so this is intentionally a
|
||||
// conservative heuristic; tune to match your provider configuration.
|
||||
// Default 21600 (6h). Set to 0 to disable the check.
|
||||
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
|
||||
SessionMaxAge int `json:"sessionMaxAge"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
@@ -247,6 +256,7 @@ func CreateConfig() *Config {
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
OverrideScopes: false, // Default to appending scopes, not overriding
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
|
||||
}
|
||||
@@ -370,6 +380,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate refresh-token max-age heuristic
|
||||
if c.MaxRefreshTokenAgeSeconds < 0 {
|
||||
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate audience if specified
|
||||
if c.Audience != "" {
|
||||
// Validate audience format - should be a valid identifier or URL
|
||||
|
||||
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(tf.fixtures.UserEmail)
|
||||
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
||||
session.SetAccessToken(tf.fixtures.AccessToken)
|
||||
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
||||
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
||||
|
||||
+132
-64
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -46,6 +47,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Hot-path fast-return: a previously-verified token has already passed
|
||||
// signature, claims, and replay checks. Skipping the parseJWT cost here
|
||||
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
|
||||
// every Grafana dashboard refresh) where the same token is validated
|
||||
// dozens of times per second by validateStandardTokens.
|
||||
if t.tokenCache != nil {
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
parsedJWT, parseErr := parseJWT(token)
|
||||
if parseErr != nil {
|
||||
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
|
||||
@@ -63,12 +75,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check token cache FIRST - if token is already verified and cached, return immediately
|
||||
// This prevents false positives when multiple goroutines validate the same token concurrently
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only check JTI blacklist for tokens that aren't already in the cache
|
||||
// This is for FIRST-TIME validation to detect replay attacks
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
@@ -315,15 +321,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
jwksURL := t.jwksURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs && jwks != nil {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL)
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
@@ -337,38 +334,12 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.safeLogDebugf("DIAGNOSTIC: Looking for kid=%s, alg=%s in JWKS", kid, alg)
|
||||
}
|
||||
|
||||
if jwks == nil {
|
||||
return fmt.Errorf("JWKS is nil, cannot verify token")
|
||||
}
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
availableKids := make([]string, 0, len(jwks.Keys))
|
||||
for _, key := range jwks.Keys {
|
||||
availableKids = append(availableKids, key.Kid)
|
||||
if key.Kid == kid {
|
||||
matchingKey = &key
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: No matching key found for kid=%s. Available kids: %v", kid, availableKids)
|
||||
}
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Found matching key for kid=%s, key type: %s", kid, matchingKey.Kty)
|
||||
}
|
||||
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
pubKey, err := t.jwkCache.GetPublicKey(context.Background(), jwksURL, kid, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
return fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
|
||||
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
|
||||
}
|
||||
@@ -451,7 +422,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
}
|
||||
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
@@ -463,7 +434,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
// Clear CSRF tokens as well to prevent any replay attacks
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
@@ -505,12 +476,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
|
||||
return false
|
||||
}
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
|
||||
return false
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
|
||||
return false
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetUserIdentifier(userIdentifier)
|
||||
|
||||
// Get token expiry information for logging
|
||||
var expiryTime time.Time
|
||||
@@ -536,7 +513,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetEmail("")
|
||||
session.SetUserIdentifier("")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -553,6 +530,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// coordinatedTokenRefresh routes a refresh-token grant through the
|
||||
// RefreshCoordinator so that concurrent requests sharing the same refresh
|
||||
// token coalesce into a single upstream call. This prevents the thundering
|
||||
// herd that yields invalid_grant when the IdP rotates refresh tokens.
|
||||
//
|
||||
// Falls back to a direct call when the coordinator is nil, which only
|
||||
// happens in tests that build TraefikOidc literals without going through
|
||||
// NewWithContext.
|
||||
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
|
||||
if t.refreshCoordinator == nil {
|
||||
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
parentCtx := context.Background()
|
||||
if req != nil {
|
||||
parentCtx = req.Context()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
|
||||
defer cancel()
|
||||
|
||||
sessionID := refreshCoordinatorSessionID(refreshToken)
|
||||
|
||||
return t.refreshCoordinator.CoordinateRefresh(
|
||||
ctx,
|
||||
sessionID,
|
||||
refreshToken,
|
||||
func() (*TokenResponse, error) {
|
||||
// Cross-replica dedup. The in-process coordinator already
|
||||
// collapses concurrent grants on this pod; this Redis-backed
|
||||
// short-TTL cache covers the (rare) case of a failover or
|
||||
// load-balancer reroute mid-refresh, where two pods would
|
||||
// otherwise both POST the same refresh_token to the IdP.
|
||||
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
|
||||
return cached, nil
|
||||
}
|
||||
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
if err == nil && resp != nil {
|
||||
t.cacheRefreshResult(sessionID, resp)
|
||||
}
|
||||
return resp, err
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
|
||||
// given refresh-token hash, if one exists and is still within its short TTL.
|
||||
// The cache wraps the universal cache, which is Redis-backed in production -
|
||||
// so a "hit" here means another Traefik replica refreshed this same token
|
||||
// within the last few seconds.
|
||||
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
|
||||
if t.refreshResultCache == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
|
||||
if !ok || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
if tr, ok := v.(*TokenResponse); ok && tr != nil {
|
||||
return tr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cacheRefreshResult stores the new TokenResponse under the refresh-token
|
||||
// hash for a short window. TTL is intentionally tight: the rotated refresh
|
||||
// token cannot be re-presented to the IdP, and any peer waiting longer than
|
||||
// this window has almost certainly given up via its own coordinator timeout.
|
||||
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
|
||||
if t.refreshResultCache == nil || resp == nil {
|
||||
return
|
||||
}
|
||||
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
|
||||
}
|
||||
|
||||
// refreshResultCacheKey namespaces refresh-result entries inside the shared
|
||||
// cache namespace.
|
||||
func refreshResultCacheKey(sessionID string) string {
|
||||
return "rt-result:" + sessionID
|
||||
}
|
||||
|
||||
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
|
||||
// Long enough for a sibling replica to observe the result, short enough that
|
||||
// a stale entry never re-supplies a token after the IdP has already moved on.
|
||||
const refreshResultCacheTTL = 5 * time.Second
|
||||
|
||||
// RevokeToken revokes a token locally by adding it to the blacklist cache.
|
||||
// It removes the token from the verification cache and adds both the token
|
||||
// and its JTI (if present) to the blacklist to prevent future use.
|
||||
@@ -1138,9 +1200,14 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
sessionManager := t.sessionManager
|
||||
logger := t.logger
|
||||
|
||||
// Only use the fast cleanup interval when actually running under `go test`.
|
||||
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
|
||||
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
|
||||
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
|
||||
// jwkCache.Cleanup both being no-ops there.
|
||||
cleanupInterval := 1 * time.Minute
|
||||
if isTestMode() {
|
||||
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
|
||||
if isTestMode() && runtime.Compiler != "yaegi" {
|
||||
cleanupInterval = 50 * time.Millisecond
|
||||
}
|
||||
|
||||
// Create cleanup function
|
||||
@@ -1182,25 +1249,27 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts group and role information from token claims.
|
||||
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
|
||||
// Parameters:
|
||||
// - idToken: The ID token containing claims to extract.
|
||||
// It parses the configured group/role claims from the supplied ID token.
|
||||
//
|
||||
// Returns:
|
||||
// - groups: Array of group names from the 'groups' claim.
|
||||
// - roles: Array of role names from the 'roles' claim.
|
||||
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
|
||||
// but not arrays of strings.
|
||||
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
|
||||
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
|
||||
// to avoid re-parsing the JWT.
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
return t.extractGroupsAndRolesFromClaims(claims)
|
||||
}
|
||||
|
||||
// extractGroupsAndRolesFromClaims extracts group and role information from
|
||||
// already-parsed claims. Hot path: callers that have a cached claims map (such
|
||||
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
|
||||
// base64+JSON decode of the JWT on every authenticated request.
|
||||
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups using configurable claim name (defaults to "groups")
|
||||
if groupsClaim, exists := claims[t.groupClaimName]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
@@ -1216,7 +1285,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles using configurable claim name (defaults to "roles")
|
||||
if rolesClaim, exists := claims[t.roleClaimName]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
|
||||
@@ -95,6 +95,7 @@ type TraefikOidc struct {
|
||||
cancelFunc context.CancelFunc
|
||||
errorRecoveryManager *ErrorRecoveryManager
|
||||
tokenResilienceManager *TokenResilienceManager
|
||||
refreshCoordinator *RefreshCoordinator
|
||||
goroutineWG *sync.WaitGroup
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
dynamicClientRegistrar *DynamicClientRegistrar
|
||||
@@ -124,11 +125,13 @@ type TraefikOidc struct {
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
|
||||
@@ -343,6 +343,31 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// Fast read path for caches whose eviction is dominated by TTL rather than
|
||||
// access-recency (token, JWK, session). Holding only an RLock here lets all
|
||||
// concurrent readers verify cached tokens in parallel — under yaegi the
|
||||
// previous unconditional Lock serialized every JWT verify on a single
|
||||
// mutex and pinned a CPU under load.
|
||||
switch c.config.Type {
|
||||
case CacheTypeToken, CacheTypeJWK, CacheTypeSession:
|
||||
c.mu.RLock()
|
||||
item, exists := c.items[key]
|
||||
if !exists {
|
||||
c.mu.RUnlock()
|
||||
atomic.AddInt64(&c.misses, 1)
|
||||
return nil, false
|
||||
}
|
||||
if !time.Now().After(item.ExpiresAt) {
|
||||
value := item.Value
|
||||
c.mu.RUnlock()
|
||||
atomic.AddInt64(&c.hits, 1)
|
||||
return value, true
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
// Expired — fall through to the write-locked slow path below to
|
||||
// remove the entry under exclusive access.
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ type UniversalCacheManager struct {
|
||||
metadataCache *UniversalCache
|
||||
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
|
||||
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
|
||||
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
|
||||
logger *Logger
|
||||
blacklistCache *UniversalCache
|
||||
cancel context.CancelFunc
|
||||
@@ -181,6 +182,18 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
|
||||
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
|
||||
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
|
||||
// in memory-only mode it's effectively redundant with RefreshCoordinator
|
||||
// but safe and cheap to keep.
|
||||
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
})
|
||||
}
|
||||
|
||||
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
|
||||
@@ -197,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
RedisPrefix: redisConfig.KeyPrefix,
|
||||
PoolSize: redisConfig.PoolSize,
|
||||
EnableMetrics: true,
|
||||
EnableTLS: redisConfig.EnableTLS,
|
||||
TLSSkipVerify: redisConfig.TLSSkipVerify,
|
||||
}
|
||||
|
||||
// Use concrete type to avoid Yaegi reflection issues with interface assignment
|
||||
@@ -387,6 +402,21 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
|
||||
createBackend("session_invalidation"),
|
||||
)
|
||||
|
||||
// Refresh-result cache - shared via Redis so concurrent refreshes across
|
||||
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
|
||||
// peers to observe a recent refresh and short enough that a stale entry
|
||||
// can't be replayed against a now-rotated refresh token.
|
||||
manager.refreshResultCache = NewUniversalCacheWithBackend(
|
||||
UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
MaxSize: 1000,
|
||||
DefaultTTL: 5 * time.Second,
|
||||
Logger: logger,
|
||||
SkipAutoCleanup: true, // Managed cleanup
|
||||
},
|
||||
createBackend("refresh_result"),
|
||||
)
|
||||
|
||||
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
|
||||
}
|
||||
|
||||
@@ -436,6 +466,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
|
||||
m.tokenTypeCache,
|
||||
m.dcrCredentialsCache,
|
||||
m.sessionInvalidationCache,
|
||||
m.refreshResultCache,
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -498,6 +529,14 @@ func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
|
||||
return m.sessionInvalidationCache
|
||||
}
|
||||
|
||||
// GetRefreshResultCache returns the short-lived refresh-result cache used to
|
||||
// coalesce refresh-token grants across Traefik replicas.
|
||||
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.refreshResultCache
|
||||
}
|
||||
|
||||
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
|
||||
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
@@ -520,7 +559,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
|
||||
// Close all caches first (they won't close the shared backend)
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
|
||||
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
|
||||
t.safeLogDebug("metadataRefreshStopChan closed")
|
||||
}
|
||||
|
||||
if t.refreshCoordinator != nil {
|
||||
t.refreshCoordinator.Shutdown()
|
||||
t.safeLogDebug("refreshCoordinator shut down")
|
||||
}
|
||||
|
||||
if t.goroutineWG != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestVerifyToken_CacheHitSkipsParse proves the hot-path optimization: when a
|
||||
// token is in the cache, VerifyToken returns nil without calling parseJWT.
|
||||
// We construct a token that PASSES the cheap format checks (3 segments, len
|
||||
// >= 10) but whose body is unparseable JSON. With the cache hit hoisted ahead
|
||||
// of parseJWT, the function returns nil. Without the hoist, parseJWT would
|
||||
// fail with "failed to parse JWT for blacklist check".
|
||||
func TestVerifyToken_CacheHitSkipsParse(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
// limiter intentionally absent; if we reached the rate-limit check
|
||||
// the test would NPE - this is a stronger assertion that we exit
|
||||
// before that point.
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Three segments separated by '.', body is junk after base64-decode + JSON.
|
||||
// Pre-fix this fails parseJWT; post-fix it returns nil because the cache
|
||||
// short-circuits.
|
||||
junkToken := "header.bm90LWpzb24.signature" // base64(not-json) in the middle
|
||||
tr.tokenCache.Set(junkToken, map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
"sub": "test",
|
||||
}, time.Hour)
|
||||
|
||||
if err := tr.VerifyToken(junkToken); err != nil {
|
||||
t.Fatalf("expected cache-hit fast path to return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyToken_CacheMissStillParses ensures we did not skip too aggressively
|
||||
// - on a cache miss, the function must still parse and reach the rate-limit
|
||||
// check. We assert by passing a syntactically valid token whose signature
|
||||
// won't verify, expecting an error from later in the pipeline.
|
||||
func TestVerifyToken_CacheMissStillParses(t *testing.T) {
|
||||
tr := &TraefikOidc{
|
||||
logger: NewLogger("error"),
|
||||
tokenCache: NewTokenCache(),
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
// no tokenBlacklist, no jwkCache - the function will fail somewhere
|
||||
// after parseJWT. We just need a non-nil error to confirm we did
|
||||
// progress past the cache check.
|
||||
}
|
||||
tr.tokenVerifier = tr
|
||||
|
||||
// Real JWT structure but unsigned/unverifiable.
|
||||
rawToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"
|
||||
|
||||
if err := tr.VerifyToken(rawToken); err == nil {
|
||||
t.Fatal("expected an error past parseJWT for an unsigned token, got nil")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user