Compare commits

...

8 Commits

Author SHA1 Message Date
lukaszraczylo 8c5df82dcf fix(azure): treat Microsoft proprietary access tokens as opaque (#134) (#138)
Followup to issue #134 — two reporters returned saying that even with the
JWKS caching fix in v1.0.7/v1.0.8, every request emitted:

  ERROR: TraefikOidcPlugin: UNKNOWN token verification failed:
    signature verification failed: crypto/rsa: verification error
  ERROR: TraefikOidcPlugin: DIAGNOSTIC: Signature verification failed for
    kid=<kid>, alg=RS256: crypto/rsa: verification error

Root cause: when an Azure tenant is configured without a custom API
resource, Microsoft issues access tokens for Microsoft Graph (or Azure
Mgmt). These tokens carry a `nonce` value in the JWT *header*; the bytes
that get signed contain SHA256(nonce), while the wire token ships the
original nonce. Any standard JWS verifier rejects the signature, which is
exactly Microsoft's intent — they document the format as proprietary and
tell client apps not to validate it
(https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
"you can't validate tokens for Microsoft Graph according to these rules
due to their proprietary format").

validateAzureTokens was nonetheless attempting JWT verification on every
JWT-shaped access token, then silently falling back to the ID token when
verification failed. Auth still worked end-to-end, but every request
spammed two error log lines.

Two-layer defense:

* validateAzureTokens now detects the proprietary-nonce header before
  calling verifyToken on the access token. When detected, the token is
  treated as opaque (matching the existing branch for non-JWT tokens) and
  validation proceeds via the ID token, exactly as Microsoft prescribes.

* VerifyJWTSignatureAndClaims downgrades the DIAGNOSTIC error log to
  debug for tokens carrying the same proprietary marker, in case any
  path outside validateAzureTokens reaches it.

Authorization still hinges on a separately-verifiable ID token — the
confused-deputy guard from CWE-441 is preserved (and explicitly tested).
2026-05-11 17:31:37 +01:00
lukaszraczylo aa96e9dbee Add sponsorship
Just in case you appreciate this project, feel generous and want to sponsor my caffeine addiction.
2026-05-10 21:25:26 +01:00
lukaszraczylo 1e33bb0a4d feat(auth): support private_key_jwt and client_secret_basic (#137)
revocation endpoints, joining the existing client_secret_post default.
Both are opt-in via the new clientAuthMethod config field. Closes #135.

private_key_jwt (RFC 7523 §2.2 / OpenID Connect Core §9)
========================================================
Plugin signs a short-lived JWT with a configured private key and presents
it as client_assertion. Use when the IdP enforces short secret TTLs or
requires secretless client auth (Microsoft Entra ID / Azure AD, Okta,
Auth0, Keycloak).

New Config fields:
  clientAuthMethod          (default: client_secret_post)
  clientAssertionPrivateKey (inline PEM)
  clientAssertionKeyPath    (PEM file path; mutually exclusive)
  clientAssertionKeyID      (JWS kid header — required)
  clientAssertionAlg        (default: RS256; RS/PS/ES 256–512 supported)

PEM forms accepted: PKCS#8, PKCS#1, SEC1.
Assertion claims: iss=sub=clientID, aud=tokenURL, iat=now, exp=now+60s,
random 16-byte hex jti per request. ECDSA signatures are raw r||s per
RFC 7515 (not ASN.1).

client_secret_basic (RFC 6749 §2.3.1)
=====================================
Sends credentials in the Authorization: Basic header instead of the
body. Both halves are form-urlencoded individually before base64 — that
encoding step is required by the spec and is NOT what stdlib's
http.Request.SetBasicAuth does, so the plugin uses its own helper. The
form body omits client_id and client_secret on this path.

Wire-up
=======
Both methods are dispatched at the same two call sites:
  helpers.go:exchangeTokens — auth_code + refresh_token grants
  token_manager.go:RevokeTokenWithProvider — RFC 7009 revocation

Existing clientSecret deployments are unaffected — empty
clientAuthMethod maps to the historical client_secret_post behavior, and
clientAssertion remains nil unless the new fields are set.

Yaegi compatibility
===================
All required crypto/rsa, crypto/ecdsa, crypto/x509, encoding/pem and
crypto/sha256/384/512 symbols are exposed by the traefik/yaegi stdlib
symbol tables (RSA SignPKCS1v15 + SignPSS, ECDSA Sign,
ParsePKCS8/1PrivateKey, ParseECPrivateKey).

Tests (16 new)
==============
Algorithm-family coverage:
  TestIssue135_SignerRSAFamily — RS256/384/512 + PS256/384/512
  TestIssue135_SignerECDSAFamily — ES256/384/512, raw r||s shape
  TestIssue135_SignerRejectsAlgKeyMismatch
  TestIssue135_SignerJTIUniqueness — 50 sigs, all jti distinct
  TestIssue135_SignerPEMVariants — PKCS#8, PKCS#1, SEC1

Config validation:
  TestIssue135_ConfigValidation — full Validate() matrix
  TestIssue135_ConfigKeyPathLoadsFile

Wire-up:
  TestIssue135_AuthCodeExchangeUsesAssertion
  TestIssue135_RefreshTokenUsesAssertion
  TestIssue135_BackcompatClientSecretPath
  TestIssue135_RevocationUsesAssertion
  TestIssue135_BuildSignerFromInlineConfig
  TestIssue135_BuildSignerDefaultsToRS256
  TestIssue135_ClientSecretBasicAuth — Authorization header, no body creds
  TestIssue135_ClientSecretBasicURLEncodesReservedChars — :, +, /, @, =, &
  TestIssue135_ClientSecretBasicRevocation — revocation parity

Documentation
=============
  README.md — required-row note + 5 optional rows + dedicated section
  docs/CONFIGURATION.md — new Client Authentication section with three
    method subsections, OpenSSL keygen snippet, RFC links
  docs/index.html — 5 new config-table rows + Private Key JWT
    explainer card
  .traefik.yml + examples/complete-traefik-config.yaml — commented
    opt-in example

Out of scope (deferred)
=======================
mTLS / tls_client_auth (RFC 8705) — separate change; requires per-call
http.Client with tls.Config.Certificates and conflicts with the current
pooled HTTP client architecture.
2026-05-09 18:02:41 +01:00
lukaszraczylo bfd702a447 fix(jwk): keep parsed JWKS in local cache only (#134) (#136)
Under yaegi (Traefik's plugin runtime) json.Marshal exposes unexported
struct fields with an X-prefixed name. parsedJWKS{ keys map[string]
crypto.PublicKey } therefore round-tripped through Redis as
{"Xkeys":{"<kid>":{"N":<huge>,"E":65537}}} — *rsa.PublicKey.N is a
*big.Int that marshals to a JSON number hundreds of digits long. On
read, json.Unmarshal into interface{} parses numbers as float64, which
cannot represent that range:

  Failed to deserialize value for key .../discovery/v2.0/keys:parsed:
  json: cannot unmarshal number 2251513...
    into Go value of type float64

Auth still worked (the JWKCache rebuilt the keys in memory on every
miss) but the error log spammed every request.

Two structural problems were behind it:

* parsedJWKS holds crypto.PublicKey interface values that aren't
  meaningfully JSON-serializable. Even on compiled Go (where the
  unexported field marshals to {}), the post-roundtrip type assertion
  v.(*parsedJWKS) silently failed and the cache was useless.
* The same pattern applied to *JWKSet — the struct shape survived JSON
  but the type assertion still failed, defeating the cache for every
  call that went through Redis.

Both keys now use the new UniversalCache.SetLocal/GetLocal pair, which
skips the configured distributed backend entirely. JWK rotation is rare
and a per-replica HTTP fetch on cold cache is cheap, so cross-replica
coherence buys nothing for these entries.

Stale Redis entries written by previous versions are simply ignored —
the new code never reads under those keys, and Redis TTL retires them.

Includes regression coverage for the Azure round-trip, the
poisoned-stale-data scenario, and the SetLocal/GetLocal isolation
contract.

patch-release
2026-05-08 13:35:23 +01:00
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00
lukaszraczylo 684a990f59 fix: reduce yaegi CPU footprint + require auth on SSE/WebSocket bypass
minor-release

Behaviour changes (potentially breaking for operators relying on the prior
unauthenticated SSE bypass):

* SSE (`Accept: text/event-stream`) and WebSocket upgrade requests now
  return 401 when no authenticated session is present. Previously the
  bypass forwarded unconditionally, which let any caller reach the
  backend by setting the right header. Excluded URLs are unchanged.
  Operators relying on unauthenticated SSE/WS access must move the path
  into ExcludedURLs.

Performance fixes (target: long-running dashboards like Grafana / ArgoCD
where many panels poll concurrently while the page stays open):

* Stop honouring isTestMode() for the singleton-token-cleanup interval
  under yaegi (the Traefik plugin runtime). In production the plugin was
  running a 20 Hz no-op cleanup ticker because runtime.Compiler ==
  "yaegi" tripped the test-mode branch.
* processAuthorizedRequest now resolves ID-token claims at most once per
  request via SessionData.GetIDTokenClaims (already cached on the
  session) and reuses them for both groups/roles extraction and
  header-template rendering. Previously every authenticated request
  parsed the JWT twice.
* Added extractGroupsAndRolesFromClaims to drive groups/roles off
  pre-parsed claims; extractGroupsAndRoles still works for tests.
* Removed the unconditional session.MarkDirty() in the header-templates
  branch. Templates only mutate request headers, not session state, so
  the prior MarkDirty was re-encrypting and rewriting all session
  cookies on every authenticated request that used header templates.

Other:

* Added isWebSocketUpgrade (RFC 6455 handshake detection — Connection:
  Upgrade + Upgrade: websocket, tolerant of multi-token Connection
  headers and case).
* Renamed applySSEUserHeaders -> applyBypassUserHeaders; it now returns
  bool so the dispatcher can reject unauthenticated SSE/WS with 401.
* Added tests for SSE and WS bypass covering both the auth-rejection
  path and the authenticated forward path.
2026-05-02 03:12:20 +01:00
lukaszraczylo 1b6c8616fd fix(refresh): coalesce refresh-token grants + bound goroutines + cache hot path (target v0.8.27) (#131)
* fix(refresh): wire RefreshCoordinator into the live refresh path

The RefreshCoordinator existed but was never instantiated. The actual
refresh path used only session.refreshMutex, which is per-SessionData
instance - and SessionData is pulled from a sync.Pool per request -
so concurrent requests sharing a refresh token had ZERO coordination.

Symptom: when access_token expired (e.g. 5min Zitadel default), every
in-flight request from a polling client (Grafana panels) entered the
refresh path simultaneously and POSTed the same refresh_token to the
IdP. With refresh-token rotation enabled (Zitadel/Authentik default),
only one grant succeeded; the rest got invalid_grant and each cleared
the entire session. Subsequent requests then thrashed in re-auth loops.

This commit:
- adds refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh,
  which collapses concurrent grants to a single upstream call per
  refresh_token hash
- exports refreshCoordinatorSessionID for both internal hashing and the
  middleware-level wireup so dedup keys stay aligned

Behavioural notes:
- nil-coordinator fallback preserves existing tests that build TraefikOidc
  literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no
  per-instance code paths change
- existing TestGetNewTokenWithRefreshToken_Concurrency still passes
  because it hits GetNewTokenWithRefreshToken directly, below the
  coordinator boundary

Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce
  to <=2 upstream calls; distinct tokens still run in parallel; nil
  coordinator falls back cleanly

* perf(cache): bound L1 backfill goroutines in HybridBackend

Get() and GetMany() previously spawned a goroutine per L2 hit to write
the value through to L1. Under sustained polling traffic (e.g. a Grafana
dashboard refreshing every 30s with N panels) this minted thousands of
goroutines, each running in Yaegi - directly contributing to the
~1000% CPU spike that pairs with the refresh-token herd.

Replace the per-hit goroutines with a single l1BackfillWorker fed by
l1BackfillBuffer, mirroring the existing asyncWriteBuffer/asyncWriteWorker
pattern for L2 writes. Buffer overflow drops the backfill (counted via
l1BackfillDrops) - a dropped backfill just means the next L2 hit for
that key re-queues it, which is safe.

Tests:
- TestHybridBackend_L1BackfillBounded: 1000 distinct L2 hits keep
  goroutine count within +20 of baseline (pre-fix it grew by ~1000)
- TestHybridBackend_L1BackfillFullDrops: drops are accounted for when
  the buffer is saturated and the worker is stopped

* feat(refresh): implement isRefreshTokenExpired heuristic

Replace the placeholder `return false` with a real check based on the
issued_at timestamp that SetRefreshToken already stamps into the session.
Gated by a new MaxRefreshTokenAgeSeconds config field (default 21600 =
6h, matching the existing comment). 0 disables the check.

This wires the previously-dead refreshTokenExpired branch in middleware.go,
which short-circuits AJAX requests with a 401 instead of letting them
hammer the IdP for a refresh token that's almost certainly stale - the
classic Grafana-after-long-pause failure mode.

Behaviour:
- maxRefreshTokenAge=0 disables the check (preserves prior behaviour)
- legacy sessions without issued_at still attempt one refresh; the IdP
  remains the source of truth on first try
- nil-receiver and nil-session guards keep test code that builds
  TraefikOidc literals safe

Tests:
- TestIsRefreshTokenExpired_DisabledWhenAgeZero
- TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp
- TestIsRefreshTokenExpired_WithinWindow
- TestIsRefreshTokenExpired_BeyondWindow
- TestIsRefreshTokenExpired_NilGuards

* perf(token): skip parseJWT on cache hit in VerifyToken

The token cache fast-return existed but ran AFTER parseJWT, so every
validation paid for base64 + JSON unmarshal even on a hit. Under bursty
traffic (e.g. 10+ concurrent panel requests on every Grafana dashboard
refresh, each calling validateStandardTokens which verifies BOTH the
access token and the ID token), this is two redundant parses per
request multiplied by the panel count.

Move the cache lookup ahead of parseJWT. On a hit the function returns
nil immediately. On a miss the original flow runs unchanged.

Also nil-guard t.tokenCache to keep partial-literal test instances safe
(matches the same pattern we already use for tokenBlacklist).

Tests:
- TestVerifyToken_CacheHitSkipsParse: cache pre-populated with claims
  for a token whose body would fail parseJWT - returns nil iff the
  fast-path bypasses the parse
- TestVerifyToken_CacheMissStillParses: a syntactically valid but
  unsigned token still errors past parseJWT on cache miss

* feat(refresh): cross-replica refresh-grant dedup via shared cache

The in-process RefreshCoordinator added in 9f96d8c already collapses
concurrent refresh-token grants on a single Traefik replica. With the
plugin's existing Redis (Dragonfly) cache infrastructure available, we
can extend that dedup across replicas: if pod A refreshes a token at
T+0 and pod B receives a request for the same session at T+1, pod B
should reuse pod A's result rather than POSTing the now-rotated refresh
token to the IdP.

Implementation:
- Add a refreshResultCache to UniversalCacheManager (memory-only when
  Redis is disabled, Redis-backed in production via the existing
  hybrid/Redis-only mode selection)
- Expose it through CacheManager.GetSharedRefreshResultCache and on the
  TraefikOidc struct as refreshResultCache (CacheInterface)
- Inside the closure passed to RefreshCoordinator.CoordinateRefresh,
  consult the cache first; on hit return immediately, on miss exchange
  with the IdP and populate the cache for peers
- 5s TTL: long enough for siblings to observe, short enough that a
  rotated refresh token cannot be re-supplied after the IdP has moved on
- Errors are intentionally NOT cached - peers must always be able to
  retry on their own

Pragmatic choice: optimistic cache rather than a hard distributed lock.
- A hard lock (SET NX + poll) doubles Redis RTT and risks dead-locks
  if a Traefik pod dies mid-grant.
- The user's BGP+Local externalTrafficPolicy already pins ingress for
  a session to one node in steady state, so cross-pod racing is rare.
- This optimistic path catches the rare failover case without adding
  failure modes.

Tests:
- TestCoordinatedTokenRefresh_CrossReplicaCacheHit: pre-populated cache
  short-circuits the upstream call entirely (0 IdP calls)
- TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache: leader stores
  a successful result for peers to find
- TestCoordinatedTokenRefresh_ErrorIsNotCached: invalid_grant must not
  poison the dedup cache - peers must retry independently
2026-04-30 18:52:39 +01:00
46 changed files with 4213 additions and 282 deletions
+15
View File
@@ -0,0 +1,15 @@
# These are supported funding model platforms
github: lukaszraczylo
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: https://monzo.me/lukaszraczylo
+13
View File
@@ -23,6 +23,19 @@ testData:
providerURL: https://accounts.google.com
clientID: 1234567890.apps.googleusercontent.com
clientSecret: your-client-secret
# Alternative: RFC 7523 private_key_jwt client authentication (Entra ID,
# Okta, Auth0, Keycloak). Replaces clientSecret with a signed JWT assertion.
# See README "Client authentication via private key JWT".
# clientAuthMethod: private_key_jwt
# clientAssertionKeyID: my-key-2026
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
# # File path option:
# clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
# # Or inline PEM (PKCS#8 / PKCS#1 / SEC1):
# clientAssertionPrivateKey: |
# -----BEGIN PRIVATE KEY-----
# MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDexampleexample
# -----END PRIVATE KEY-----
sessionEncryptionKey: potato-secret-is-at-least-32-bytes-long
callbackURL: /oauth2/callback
+61 -1
View File
@@ -96,7 +96,7 @@ More example configs in [`examples/`](examples/).
|-----------|-------------|
| `providerURL` | Issuer URL (used for OIDC discovery). |
| `clientID` | OAuth 2.0 client ID. |
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. |
| `clientSecret` | OAuth 2.0 client secret. Supports `urn:k8s:secret:ns:name:key`. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`; optional with `private_key_jwt`. |
| `sessionEncryptionKey` | Cookie encryption key, **min 32 bytes**. |
| `callbackURL` | Callback path, e.g. `/oauth2/callback`. |
@@ -121,6 +121,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `cookiePrefix` | `_oidc_raczylo_` | Unique prefix per middleware instance to isolate sessions. |
| `sessionMaxAge` | `86400` | Session lifetime in seconds. |
| `refreshGracePeriodSeconds` | `60` | Proactively refresh tokens this many seconds before expiry. |
| `maxRefreshTokenAgeSeconds` | `21600` | Heuristic max stored refresh-token lifetime (6h). Past this, the plugin treats the RT as expired without contacting the IdP — returns 401 to AJAX, full re-auth on navigations. Set `0` to disable. Tune to match your IdP's RT TTL. |
| `rateLimit` | `100` | Requests/sec. Min `10`. |
| `logLevel` | `info` | `debug`, `info`, `error`. |
| `audience` | `clientID` | Custom access-token audience (Auth0 custom APIs). |
@@ -132,6 +133,11 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
| `stripAuthCookies` | `false` | Strip OIDC cookies from backend hop (mitigates HTTP 431). |
| `caCertPath` / `caCertPEM` | none | Trust an internal CA for the provider's TLS. |
| `insecureSkipVerify` | `false` | **Local dev only.** Disables TLS verification, logs a security warning. |
| `clientAuthMethod` | `client_secret_post` | Client auth method. Set `private_key_jwt` for RFC 7523 JWT assertions (Entra ID, Okta, Auth0, Keycloak). See [Client authentication via private key JWT](#client-authentication-via-private-key-jwt). |
| `clientAssertionPrivateKey` | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. |
| `clientAssertionKeyPath` | none | File path to PEM private key for `private_key_jwt`. |
| `clientAssertionKeyID` | none | JWS `kid` header. Required when `clientAuthMethod=private_key_jwt`; must match the public key registered with the IdP. |
| `clientAssertionAlg` | `RS256` | JWS alg for `private_key_jwt`. Supported: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
| `enableBackchannelLogout` / `backchannelLogoutURL` | `false` / none | OIDC Back-Channel Logout (server-to-server). |
| `enableFrontchannelLogout` / `frontchannelLogoutURL` | `false` / none | OIDC Front-Channel Logout (iframe). |
| `redis` | disabled | See [docs/REDIS.md](docs/REDIS.md). |
@@ -165,6 +171,22 @@ Each instance must use a unique `cookiePrefix` **and** `sessionEncryptionKey`,
otherwise a session minted by one instance can grant access through another.
See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87).
### SSE and WebSocket endpoints
Browser clients cannot follow an OIDC `302` redirect on an SSE stream or a
WebSocket upgrade. The middleware handles this automatically:
- **SSE** (`Accept: text/event-stream`) and **WebSocket** (`Upgrade: websocket`)
requests skip the OIDC redirect.
- They are **not** unauthenticated — a valid encrypted session cookie is
required, otherwise the request is rejected. The session must already exist
(i.e. the user logged in via a normal HTTP page first).
- `X-Forwarded-User` is forwarded from the session.
- Validation is cookie-only (no JWK fetch), so streaming keeps working during
brief IdP outages.
No configuration needed — this is implicit behavior.
### HTTP 431 from backends
Either the ID token or the chunked OIDC cookies overflow your backend's header
@@ -196,6 +218,44 @@ caCertPEM: |
Both can be combined. An unparseable bundle fails the plugin at startup.
See [#125](https://github.com/lukaszraczylo/traefikoidc/issues/125).
### Client authentication via private key JWT
Use when your IdP enforces short-lived secrets or pushes secretless client auth
— Microsoft Entra ID / Azure AD, Okta, Auth0, Keycloak. Instead of sending a
static `clientSecret`, the plugin signs a short-lived JWT and submits it as
`client_assertion` per [RFC 7523](https://www.rfc-editor.org/rfc/rfc7523).
Minimal config:
```yaml
clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
clientAssertionKeyID: my-key-2026
# clientAssertionAlg: RS256 # default; or PS256/384/512, ES256/384/512
```
Or inline:
```yaml
clientAuthMethod: private_key_jwt
clientAssertionPrivateKey: |
-----BEGIN PRIVATE KEY-----
...
-----END PRIVATE KEY-----
clientAssertionKeyID: my-key-2026
```
Accepted PEM forms: PKCS#8 (`PRIVATE KEY`), PKCS#1 (`RSA PRIVATE KEY`), SEC1
(`EC PRIVATE KEY`). The assertion uses `iss=sub=clientID`, `aud=tokenURL`, 60s
lifetime, random hex `jti` per request. Sent on `/token` (auth-code + refresh)
and `/revoke`. The `kid` must match the public key registered with the IdP.
`clientSecret` becomes optional with `private_key_jwt`. Existing
`client_secret_post` setups are unaffected. Keys are parsed once at startup —
rotation requires a Traefik reload.
See [issue #135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
### Environment variable names containing `API`
Traefik reserves `TRAEFIK_API_*`. User vars whose name contains `API` (e.g.
+1 -1
View File
@@ -1491,7 +1491,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("Failed to set authenticated: %v", err)
}
session.SetEmail("user@company.com")
session.SetUserIdentifier("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
+30 -7
View File
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"strings"
"time"
)
// validateRedirectCount checks if redirect limit is exceeded and handles the error
@@ -42,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
@@ -249,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetUserIdentifier(userIdentifier)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
@@ -289,7 +290,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -360,9 +361,31 @@ func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
return !strings.Contains(accept, "text/html")
}
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
// isRefreshTokenExpired checks whether the stored refresh token is likely
// past its useful lifetime, using the cookie-side issued_at timestamp set by
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
// MaxRefreshTokenAgeSeconds; 0 disables the check).
//
// The point of this check is to short-circuit the refresh path BEFORE the
// thundering herd hits the IdP for a token the provider has almost certainly
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
// style polling clients from looping on invalid_grant after a long pause.
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
// This is a heuristic check - actual implementation would depend on
// the specific provider and token metadata
return false // Placeholder implementation
if t == nil || session == nil {
return false
}
if t.maxRefreshTokenAge <= 0 {
return false
}
issuedAt := session.GetRefreshTokenIssuedAt()
if issuedAt.IsZero() {
// No timestamp recorded (legacy session pre-dating the issued_at
// field). Don't force a re-auth - attempt refresh once and let the
// IdP be the source of truth.
return false
}
return time.Since(issuedAt) > t.maxRefreshTokenAge
}
+4 -4
View File
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Pre-populate session with old data
_ = session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-access-token-with-many-characters")
session.SetRefreshToken("old-refresh-token-with-many-characters")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Verify old data is cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
// Verify new data is set
s.Equal(csrfToken, session.GetCSRF())
@@ -711,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
session, err := sessionManager.GetSession(req)
s.Require().NoError(err)
_ = session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
session.mainSession.Values["redirect_count"] = 3
@@ -720,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
// Session should be cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
s.Empty(session.GetIDToken())
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
+8
View File
@@ -113,6 +113,14 @@ func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
// by the refresh path to coalesce grants across Traefik replicas via Redis.
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
+295
View File
@@ -0,0 +1,295 @@
package traefikoidc
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"os"
"time"
)
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
// algorithm for private_key_jwt (RFC 7523 §2.2).
func isSupportedClientAssertionAlg(alg string) bool {
switch alg {
case "RS256", "RS384", "RS512",
"PS256", "PS384", "PS512",
"ES256", "ES384", "ES512":
return true
}
return false
}
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
type ClientAssertionSigner struct {
key crypto.PrivateKey
alg string
kid string
// rand is the entropy source for jti generation and PSS/ECDSA signing.
// Defaults to crypto/rand.Reader when nil.
rand io.Reader
// now returns the current time. Defaults to time.Now when nil.
now func() time.Time
}
// NewClientAssertionSigner parses pemBytes as a private key, validates that
// alg is consistent with the key type, and returns a ready-to-use signer.
// kid is placed verbatim in the JWS header.
//
// PEM block types understood:
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
// - "RSA PRIVATE KEY" → PKCS#1
// - "EC PRIVATE KEY" → SEC1
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
if !isSupportedClientAssertionAlg(alg) {
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
}
if kid == "" {
return nil, fmt.Errorf("kid must not be empty")
}
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, fmt.Errorf("no PEM block found in private key material")
}
var key crypto.PrivateKey
var parseErr error
switch block.Type {
case "PRIVATE KEY":
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
default:
// Best-effort fallback for unknown block types.
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
}
if parseErr != nil {
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
}
if err := validateAlgKeyMatch(alg, key); err != nil {
return nil, err
}
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
}
// validateAlgKeyMatch returns an error when alg implies a key type that does
// not match the actual key.
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
switch alg[0] {
case 'R', 'P': // RS* or PS*
if _, ok := key.(*rsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
}
case 'E': // ES*
if _, ok := key.(*ecdsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
}
}
return nil
}
// Sign constructs and returns a signed client_assertion JWT.
// audience is typically the token endpoint URL (RFC 7523 §3).
// clientID is used as both iss and sub per RFC 7523 §2.2.
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
rander := s.rand
if rander == nil {
rander = rand.Reader
}
nowFn := s.now
if nowFn == nil {
nowFn = time.Now
}
now := nowFn()
// 16 random bytes as lowercase hex for jti uniqueness.
jtiBytes := make([]byte, 16)
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
return "", fmt.Errorf("failed to generate jti: %w", err)
}
jti := hex.EncodeToString(jtiBytes)
header := map[string]string{
"alg": s.alg,
"typ": "JWT",
"kid": s.kid,
}
hdrJSON, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
}
claims := map[string]any{
"iss": clientID,
"sub": clientID,
"aud": audience,
"jti": jti,
"iat": now.Unix(),
"exp": now.Add(60 * time.Second).Unix(),
}
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
}
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signingInput := hdrB64 + "." + claimsB64
sig, err := s.sign(rander, []byte(signingInput))
if err != nil {
return "", err
}
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
}
// sign computes raw signature bytes for signingInput per s.alg.
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
// surface internal misuse loudly instead of via panic.
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
switch s.alg {
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
rsaKey, ok := s.key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
}
hash := rsaHashForAlg(s.alg)
digest := hashSum(hash, input)
if s.alg[0] == 'R' {
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
}
return signRSAPSS(rander, rsaKey, hash, digest)
case "ES256", "ES384", "ES512":
ecKey, ok := s.key.(*ecdsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
}
hash := ecHashForAlg(s.alg)
digest := hashSum(hash, input)
return signECDSA(rander, ecKey, digest)
}
return nil, fmt.Errorf("unhandled alg %q", s.alg)
}
func rsaHashForAlg(alg string) crypto.Hash {
switch alg {
case "RS256", "PS256":
return crypto.SHA256
case "RS384", "PS384":
return crypto.SHA384
case "RS512", "PS512":
return crypto.SHA512
}
return 0
}
func ecHashForAlg(alg string) crypto.Hash {
switch alg {
case "ES256":
return crypto.SHA256
case "ES384":
return crypto.SHA384
case "ES512":
return crypto.SHA512
}
return 0
}
func hashSum(h crypto.Hash, input []byte) []byte {
switch h {
case crypto.SHA256:
sum := sha256.Sum256(input)
return sum[:]
case crypto.SHA384:
sum := sha512.Sum384(input)
return sum[:]
case crypto.SHA512:
sum := sha512.Sum512(input)
return sum[:]
}
return nil
}
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
if err != nil {
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
}
return sig, nil
}
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
if err != nil {
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
}
return sig, nil
}
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
r, ss, err := ecdsa.Sign(rander, key, digest)
if err != nil {
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
}
byteLen := (key.Curve.Params().BitSize + 7) / 8
sig := make([]byte, 2*byteLen)
padBigInt(sig[0:byteLen], r)
padBigInt(sig[byteLen:], ss)
return sig, nil
}
// padBigInt writes n as a fixed-width big-endian integer into buf.
func padBigInt(buf []byte, n *big.Int) {
b := n.Bytes()
copy(buf[len(buf)-len(b):], b)
}
// buildClientAssertionSignerFromConfig loads key material and constructs a
// ClientAssertionSigner. Called from NewWithContext when
// ClientAuthMethod == "private_key_jwt".
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
var pemBytes []byte
if config.ClientAssertionPrivateKey != "" {
pemBytes = []byte(config.ClientAssertionPrivateKey)
} else {
data, err := os.ReadFile(config.ClientAssertionKeyPath)
if err != nil {
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
}
pemBytes = data
}
alg := config.ClientAssertionAlg
if alg == "" {
alg = "RS256"
}
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
}
+4 -4
View File
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
session.SetCSRF(csrfToken)
session.SetNonce("test-nonce")
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("old-refresh-token")
session.SetIDToken("old-id-token")
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Now perform selective clearing (as done in the fix)
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// Set initial session data
session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-token")
session.SetCSRF("existing-csrf")
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
// NEW BEHAVIOR: Selective clearing
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
+158 -1
View File
@@ -5,6 +5,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Client Authentication](#client-authentication)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
@@ -22,7 +23,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
@@ -45,6 +46,129 @@ spec:
---
## Client Authentication
The middleware supports three client authentication methods at the token and
revocation endpoints. The default is `client_secret_post` (current behavior);
`private_key_jwt` is opt-in and backwards compatible.
| Method | Default | Description |
|--------|---------|-------------|
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
Select via `clientAuthMethod`:
```yaml
clientAuthMethod: private_key_jwt
```
### client_secret_post
Default. The plugin sends `client_id` and `client_secret` as form parameters
in the token / revocation request body. No additional configuration required.
### private_key_jwt
Asymmetric client authentication per
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
IdP enforces short secret TTLs, when policy mandates secretless clients, or
when you want to avoid distributing a shared secret to the proxy.
For each token / revocation request the plugin builds a JWS with:
- `iss` = `sub` = `clientID`
- `aud` = token endpoint URL
- `iat` = now, `exp` = now + 60s
- `jti` = random hex per request
- `kid` header = `clientAssertionKeyID`
**Required fields:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
**Example — inline PEM:**
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://idp.example.com
clientID: my-client-id
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
clientAuthMethod: private_key_jwt
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
clientAssertionPrivateKey: |
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
... (truncated) ...
-----END PRIVATE KEY-----
```
**Example — key on disk:**
```yaml
clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
```
**Generating an RS256 key with OpenSSL:**
```bash
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
-out client-key.pem
openssl rsa -in client-key.pem -pubout -out client-pub.pem
```
Register `client-pub.pem` (or its JWK form) with your IdP under the same
`kid` you set in `clientAssertionKeyID`.
**Notes:**
- The private key is parsed once at plugin startup. Key rotation requires a
Traefik reload.
- Assertion lifetime is fixed at 60 seconds.
- A fresh random `jti` is generated per request.
- The `aud` claim is the token endpoint URL (from discovery).
- Tracking issue:
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
### client_secret_basic
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
in an `Authorization: Basic` header instead of the body. Both halves
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
the token endpoint and rejects credentials in the body.
```yaml
clientAuthMethod: client_secret_basic
clientID: your-client-id
clientSecret: your-client-secret
```
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
---
## Optional Parameters
| Parameter | Type | Default | Description |
@@ -59,6 +183,11 @@ spec:
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
### TLS Termination at Load Balancer
@@ -70,6 +199,33 @@ overwrite it).
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
dev). Otherwise leave it at default.
### Streaming Endpoints (SSE and WebSocket)
The middleware automatically bypasses the OIDC redirect for two request kinds
that browsers cannot follow a 302 on:
| Bypass | Triggered by |
|--------|--------------|
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
These requests do **not** require any explicit configuration — they are
handled implicitly. However, the bypass is **not** unauthenticated:
- A valid, encrypted session cookie is required. Requests without one are
rejected (the connection cannot proceed to the backend).
- The session cookie is sealed with `sessionEncryptionKey`, so the
`authenticated` flag cannot be forged.
- Validation is cookie-only — no JWK fetch / signature verification — so
streaming endpoints keep working when the OIDC provider is briefly
unavailable.
- The user identifier from the session is forwarded as `X-Forwarded-User`
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
For browser clients, the user must complete the normal OIDC flow on a
regular HTTP page first; the resulting session cookie is then reused on the
SSE / WebSocket connection.
---
## Security Options
@@ -113,6 +269,7 @@ strictAudienceValidation: true
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
+46 -3
View File
@@ -642,7 +642,7 @@ spec:
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientSecret</code></td>
<td class="py-2 px-3">OAuth 2.0 client secret</td>
<td class="py-2 px-3">OAuth 2.0 client secret. Only required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is unset or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">sessionEncryptionKey</code></td>
@@ -718,6 +718,11 @@ spec:
<td class="py-2 px-3">86400</td>
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
<td class="py-2 px-3">21600</td>
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
<td class="py-2 px-3">_oidc_raczylo_</td>
@@ -748,15 +753,48 @@ spec:
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Require RFC 7662 introspection for opaque tokens</td>
</tr>
<tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">disableReplayDetection</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Disable JTI replay detection (for multi-replica without Redis)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code></td>
<td class="py-2 px-3">client_secret_post</td>
<td class="py-2 px-3">Selects how the plugin authenticates to the token endpoint. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">Inline PEM private key used to sign client assertions for <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyPath</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">Path to a PEM private key file. Alternative to <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyID</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">JWS <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">kid</code> header value. Required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionAlg</code></td>
<td class="py-2 px-3">RS256</td>
<td class="py-2 px-3">Signing algorithm for the client assertion. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES512</code>.</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Private Key JWT (RFC 7523)</h3>
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Use this when your IdP (Entra ID, Okta, Auth0, Keycloak) pressures short-lived secrets, or when policy mandates secretless service-to-service authentication. The plugin signs a 60-second assertion with the configured private key and sends it as <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_assertion</code> instead of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret</code>. Public-key registration on the IdP replaces shared-secret rotation. See <a href="https://www.rfc-editor.org/rfc/rfc7523" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">RFC 7523</a> and <a href="https://github.com/lukaszraczylo/traefikoidc/issues/135" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">issue #135</a>.</p>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc-client.pem
clientAssertionKeyID: my-client-key-2026
# clientSecret no longer required</code></pre>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Google Workspace with Domain Restriction</h3>
@@ -858,7 +896,12 @@ spec:
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable TLS for Redis connections</td>
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
</tr>
</tbody>
</table>
+10
View File
@@ -101,6 +101,16 @@ http:
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# ----------------------------------------------------------------
# Optional: switch to RFC 7523 private_key_jwt client auth
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
# signed JWT assertion. See README for details and PEM formats.
# ----------------------------------------------------------------
# clientAuthMethod: "private_key_jwt"
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
# clientAssertionKeyID: "prod-key-2026"
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
+37 -4
View File
@@ -107,9 +107,12 @@ type TokenResponse struct {
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
"grant_type": {grantType},
}
// client_id is sent in the body for every method except client_secret_basic,
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
data.Set("client_id", t.clientID)
}
if grantType == "authorization_code" {
@@ -141,16 +144,33 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
// Read tokenURL with RLock
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
useBasicAuth := false
if t.clientAssertion != nil {
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
if err != nil {
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
}
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
data.Set("client_assertion", assertion)
} else if t.clientAuthMethod == "client_secret_basic" {
useBasicAuth = true
} else {
data.Set("client_secret", t.clientSecret)
}
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if useBasicAuth {
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
}
resp, err := client.Do(req)
if err != nil {
@@ -423,6 +443,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
return u.String(), nil
}
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
// client_id and client_secret are form-urlencoded individually, joined with a
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
// which skips the form-urlencode step — that matters for credentials with
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
// otherwise diverge from what the spec mandates.
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
user := url.QueryEscape(clientID)
pass := url.QueryEscape(clientSecret)
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
req.Header.Set("Authorization", "Basic "+auth)
}
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
// This ensures that OAuth scope parameters don't contain duplicates which could
// cause issues with some authorization servers.
+3
View File
@@ -24,6 +24,7 @@ type Config struct {
Type BackendType
RedisAddr string
RedisPassword string
TLSServerName string
PoolSize int
RedisDB int
CleanupInterval time.Duration
@@ -34,6 +35,8 @@ type Config struct {
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
EnableTLS bool
TLSSkipVerify bool
}
// DefaultConfig returns a default configuration for in-memory caching
+73 -26
View File
@@ -20,6 +20,7 @@ type HybridBackend struct {
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
l1BackfillBuffer chan *l1BackfillItem
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
@@ -28,6 +29,7 @@ type HybridBackend struct {
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
l1BackfillDrops atomic.Int64
fallbackMode atomic.Bool
}
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
ttl time.Duration
}
// l1BackfillItem represents a deferred write of an L2-resolved value back into
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
type l1BackfillItem struct {
key string
value []byte
ttl time.Duration
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
h.wg.Add(1)
go h.asyncWriteWorker()
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
// L2 hits regardless of request rate.
h.wg.Add(1)
go h.l1BackfillWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
// Populate L1 cache with value from L2 (write-through on read).
// Hand off to the bounded backfill worker instead of spawning a goroutine
// per read - under burst that would mint thousands of goroutines.
h.queueL1Backfill(key, value, ttl)
return value, ttl, true, nil
}
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
// Close async write channel
close(h.asyncWriteBuffer)
close(h.l1BackfillBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
}
}
} else {
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
h.queueL1Backfill(key, value, ttl)
}
}
}
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
return nil
}
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
// Drops on full buffer to keep the read path constant-time; the next L2 hit
// for the same key simply re-queues it.
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
select {
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
default:
h.l1BackfillDrops.Add(1)
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", key)
}
}
// l1BackfillWorker drains the backfill queue serially. Single worker is
// intentional - L1 writes are local and cheap, and serializing them keeps
// goroutine count bounded under any read rate.
func (h *HybridBackend) l1BackfillWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items best-effort then exit.
for len(h.l1BackfillBuffer) > 0 {
select {
case item := <-h.l1BackfillBuffer:
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.l1BackfillBuffer:
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", item.key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", item.key)
}
cancel()
}
}
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
+112
View File
@@ -0,0 +1,112 @@
//go:build !yaegi
package backends
import (
"context"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 256,
})
require.NoError(t, err)
defer hybrid.Close()
ctx := context.Background()
const burst = 1000
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
// fresh L1 backfill enqueue.
for i := 0; i < burst; i++ {
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
}
baseline := runtime.NumGoroutine()
// Issue the burst as fast as possible; the backfill worker MUST be the
// only goroutine doing L1 writes. Allow brief slack for the test runtime
// scheduling but anything north of +20 means goroutine leakage.
peak := baseline
for i := 0; i < burst; i++ {
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
require.NoError(t, err)
require.True(t, exists)
if g := runtime.NumGoroutine(); g > peak {
peak = g
}
}
delta := peak - baseline
if delta > 20 {
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
delta, baseline, peak)
}
// L1 must eventually catch up via the worker. Worker drains serially so
// give it a generous window proportional to the burst size.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
var populated int
for i := 0; i < burst; i++ {
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
populated++
}
}
// Be lenient: drops are acceptable under buffer pressure, just want
// most of the keys to make it.
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
}
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
// buffer is saturated. Drops must be counted, never block, never spawn a
// goroutine.
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
// Tiny buffer + slow primary writes via failSet so the worker stays
// blocked enough to overflow the buffer.
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 4,
})
require.NoError(t, err)
defer hybrid.Close()
// Stop the worker from draining: cancel the underlying context so the
// worker bails out, leaving us with a cold buffer and the queue method
// itself responsible for drop accounting.
hybrid.cancel()
// Wait for worker to exit so it can't drain.
time.Sleep(50 * time.Millisecond)
for i := 0; i < 50; i++ {
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
}
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
"expected some drops when buffer is saturated and worker is stopped")
}
+3
View File
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
TLSServerName: config.TLSServerName,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
EnableTLS: config.EnableTLS,
TLSSkipVerify: config.TLSSkipVerify,
}
pool, err := NewConnectionPool(poolConfig)
+25 -3
View File
@@ -2,6 +2,7 @@ package backends
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@@ -31,6 +32,7 @@ type ConnectionPool struct {
type PoolConfig struct {
Address string
Password string
TLSServerName string // SNI server name; defaults to host(Address) when empty
DB int
MaxConnections int
ConnectTimeout time.Duration
@@ -39,6 +41,8 @@ type PoolConfig struct {
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
}
// NewConnectionPool creates a new connection pool
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
conn, err = p.createConnection(ctx)
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
var conn net.Conn
var err error
if p.config.EnableTLS {
serverName := p.config.TLSServerName
if serverName == "" {
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
serverName = host
}
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
MinVersion: tls.VersionTLS12,
}
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
} else {
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
+230
View File
@@ -0,0 +1,230 @@
package backends
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// drainRESPRequest consumes a single RESP request (array or inline) from r and
// returns true on success. Any read error returns false.
func drainRESPRequest(r *bufio.Reader) bool {
header, err := r.ReadString('\n')
if err != nil {
return false
}
if !strings.HasPrefix(header, "*") {
return true // inline command (single line) — already consumed
}
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
if err != nil || n <= 0 {
return false
}
for i := 0; i < n; i++ {
// Each bulk: "$len\r\n<bytes>\r\n"
if _, err := r.ReadString('\n'); err != nil {
return false
}
if _, err := r.ReadString('\n'); err != nil {
return false
}
}
return true
}
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
// answer PING with +PONG. Returns the listener address and a self-signed cert.
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
require.NoError(t, err)
tlsCert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS12,
})
require.NoError(t, err)
var wg sync.WaitGroup
stopCh := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
c, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
defer conn.Close()
reader := bufio.NewReader(conn)
for {
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if !drainRESPRequest(reader) {
return
}
_, _ = conn.Write([]byte("+PONG\r\n"))
}
}(c)
}
}()
stop = func() {
close(stopCh)
_ = listener.Close()
wg.Wait()
}
return listener.Addr().String(), der, stop
}
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
// Regression test for issue #133 (enableTLS not propagated to client).
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
// TLSSkipVerify=false rejects a self-signed server cert.
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: false,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "tls")
}
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
// fails to handshake against a plain (non-TLS) listener.
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
plain, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer plain.Close()
go func() {
for {
c, acceptErr := plain.Accept()
if acceptErr != nil {
return
}
_ = c.Close()
}
}()
pool, err := NewConnectionPool(&PoolConfig{
Address: plain.Addr().String(),
MaxConnections: 1,
ConnectTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
}
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
// when EnableTLS=false (default).
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: false,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
+135
View File
@@ -0,0 +1,135 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
// the fix for issue #132: token refresh path hardcoded the "email" claim and
// ignored the configured userIdentifierClaim. Keycloak users without an email
// claim (using sub or another identifier) were being kicked out on refresh
// even though their initial login worked.
//
// The callback path (auth_flow.go) already honored userIdentifierClaim with
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
// after PR #100 (commit a316a98).
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
tests := []struct {
claims map[string]any
name string
userIdentifierClaim string
expectedIdentifier string
expectSuccess bool
}{
{
name: "sub claim configured, only sub present (Keycloak no-email case)",
userIdentifierClaim: "sub",
claims: map[string]any{
"sub": "user-uuid-keycloak-12345",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user-uuid-keycloak-12345",
},
{
name: "preferred_username configured, claim present",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "user-uuid-12345",
"preferred_username": "alice",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "alice",
},
{
name: "configured claim missing, falls back to sub",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "fallback-sub-id",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "fallback-sub-id",
},
{
name: "email default, email present (backward compatibility)",
userIdentifierClaim: "email",
claims: map[string]any{
"sub": "user-uuid-12345",
"email": "user@example.com",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user@example.com",
},
{
name: "email default, no email and no sub - refresh fails",
userIdentifierClaim: "email",
claims: map[string]any{
"exp": float64(9999999999),
},
expectSuccess: false,
expectedIdentifier: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionManager, err := NewSessionManager(
"test-encryption-key-32-bytes-long!!",
false,
"",
"",
0,
NewLogger("error"),
)
if err != nil {
t.Fatalf("session manager: %v", err)
}
defer sessionManager.Shutdown()
capturedClaims := tt.claims
tOidc := &TraefikOidc{
logger: NewLogger("error"),
userIdentifierClaim: tt.userIdentifierClaim,
sessionManager: sessionManager,
tokenExchanger: &EnhancedMockTokenExchanger{
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
IDToken: "new-id-token-jwt",
ExpiresIn: 3600,
},
},
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
extractClaimsFunc: func(token string) (map[string]any, error) {
return capturedClaims, nil
},
}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("get session: %v", err)
}
defer session.returnToPoolSafely()
session.SetRefreshToken("initial-refresh-token")
refreshed := tOidc.refreshToken(rw, req, session)
if refreshed != tt.expectSuccess {
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
}
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
}
})
}
}
+449
View File
@@ -0,0 +1,449 @@
package traefikoidc
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
// signGraphStyleAccessToken builds a JWT in Microsoft's Graph proprietary
// nonce-header form: bytes that get signed contain the SHA256 hash of the
// nonce, while the wire token ships the original nonce. A standard JWS
// verifier always rejects these with `crypto/rsa: verification error`, which
// is why Microsoft documents Graph access tokens as opaque to client apps:
//
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
// "you can't validate tokens for Microsoft Graph according to these rules
// due to their proprietary format"
func signGraphStyleAccessToken(t *testing.T, key *rsa.PrivateKey, kid, originalNonce string, claims map[string]any) string {
t.Helper()
wireHeader := map[string]any{
"alg": "RS256",
"kid": kid,
"typ": "JWT",
"nonce": originalNonce,
}
wireHeaderJSON, err := json.Marshal(wireHeader)
require.NoError(t, err)
hashed := sha256.Sum256([]byte(originalNonce))
signedHeader := map[string]any{
"alg": "RS256",
"kid": kid,
"typ": "JWT",
"nonce": fmt.Sprintf("%x", hashed),
}
signedHeaderJSON, err := json.Marshal(signedHeader)
require.NoError(t, err)
claimsJSON, err := json.Marshal(claims)
require.NoError(t, err)
wireHeaderB64 := base64.RawURLEncoding.EncodeToString(wireHeaderJSON)
signedHeaderB64 := base64.RawURLEncoding.EncodeToString(signedHeaderJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signedInput := signedHeaderB64 + "." + claimsB64
hSign := sha256.Sum256([]byte(signedInput))
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hSign[:])
require.NoError(t, err)
return wireHeaderB64 + "." + claimsB64 + "." + base64.RawURLEncoding.EncodeToString(sig)
}
// newAzureFollowupOIDC produces a TraefikOidc instance wired for an Azure
// AD tenant with a captured error log buffer. Used by the issue #134 followup
// tests to assert log behavior during validateAzureTokens flows.
func newAzureFollowupOIDC(t *testing.T, jwks *JWKSet) (*TraefikOidc, *bytes.Buffer) {
t.Helper()
tc := newTestCleanup(t)
errBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
tokenCache := tc.addTokenCache(NewTokenCache())
tokenBlacklist := tc.addCache(NewCache())
oidc := &TraefikOidc{
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
clientID: "test-client-id",
audience: "test-client-id",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
jwkCache: &MockJWKCache{JWKS: jwks},
tokenCache: tokenCache,
tokenBlacklist: tokenBlacklist,
extractClaimsFunc: extractClaims,
}
oidc.tokenVerifier = oidc
oidc.jwtVerifier = oidc
require.True(t, oidc.isAzureProvider(), "fixture must be detected as Azure provider")
return oidc, errBuf
}
// authedSessionWithTokens returns a SessionData populated with the supplied
// access and ID tokens, marked authenticated and recently created. The
// SessionManager carries a real ChunkManager so that GetAccessToken /
// GetIDToken / GetRefreshToken behave like the production code path.
func authedSessionWithTokens(t *testing.T, accessToken, idToken string) *SessionData {
t.Helper()
chunkLogger := NewLogger("error")
chunkManager := NewChunkManager(chunkLogger)
t.Cleanup(chunkManager.Shutdown)
sd := CreateMockSessionData()
sd.manager = &SessionManager{
sessionMaxAge: 24 * time.Hour,
chunkManager: chunkManager,
logger: chunkLogger,
}
sd.mainSession = sessions.NewSession(nil, "main")
sd.mainSession.Values["authenticated"] = true
sd.mainSession.Values["created_at"] = time.Now().Unix()
sd.accessSession = sessions.NewSession(nil, "access")
sd.accessSession.Values["token"] = accessToken
sd.accessSession.Values["compressed"] = false
sd.idTokenSession = sessions.NewSession(nil, "id")
sd.idTokenSession.Values["token"] = idToken
sd.idTokenSession.Values["compressed"] = false
sd.refreshSession = sessions.NewSession(nil, "refresh")
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = false
return sd
}
// TestIssue134_Followup_GraphAccessTokenReproducesUsersError sanity-checks
// that our crafted Graph-style token reproduces the exact rsa error string
// quoted on the issue thread (dada-engineer 2026-05-08, friek 2026-05-11).
//
// Sanity test: must always pass, regardless of the issue #134 followup fix.
// It exists so a future contributor does not accidentally weaken the
// reproducer and assume the followup fix is no longer needed.
func TestIssue134_Followup_GraphAccessTokenReproducesUsersError(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-followup-kid"
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
parsedJWT, err := parseJWT(graphToken)
require.NoError(t, err)
pubKey := &rsaKey.PublicKey
alg, _ := parsedJWT.Header["alg"].(string)
verifyErr := verifySignatureWithKey(graphToken, pubKey, alg)
require.Error(t, verifyErr)
assert.Contains(t, verifyErr.Error(), "crypto/rsa: verification error",
"reproducer must emit the exact error string reported on issue #134")
}
// TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken is the
// failing-then-passing test for the followup fix.
//
// Symptom (before fix): validateAzureTokens calls verifyToken on every
// JWT-shaped access token. For Microsoft Graph access tokens (the default
// when no custom resource is registered), verification always fails with
// `crypto/rsa: verification error`, generating two error log lines per
// request:
//
// UNKNOWN token verification failed: signature verification failed:
// crypto/rsa: verification error
// DIAGNOSTIC: Signature verification failed for kid=<kid>, alg=RS256:
// crypto/rsa: verification error
//
// Microsoft's own documentation tells client apps not to validate Graph
// access tokens. The fix matches that guidance: when an Azure access token
// carries Microsoft's proprietary `nonce` JWT header, treat it as opaque
// (skip JWT verification, fall through to ID token validation).
func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-followup-kid"
jwk := JWK{
Kty: "RSA",
Use: "sig",
Alg: "RS256",
Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
now := time.Now()
exp := now.Add(time.Hour).Unix()
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-azure-graph", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": exp,
"iat": now.Unix(),
"sub": "user-azure-id",
"appid": "test-client-id",
"scp": "User.Read",
})
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"email": "user@example.com",
"nonce": "id-token-oidc-nonce",
"jti": "id-token-jti-followup",
})
require.NoError(t, err)
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, idToken)
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
output := errBuf.String()
assert.NotContains(t, output, "crypto/rsa: verification error",
"validateAzureTokens must not log rsa verification error for Graph-style access tokens; got: %q", output)
assert.NotContains(t, output, "DIAGNOSTIC: Signature verification failed",
"DIAGNOSTIC line must not fire for Graph-style access tokens; got: %q", output)
assert.NotContains(t, output, "UNKNOWN token verification failed",
"UNKNOWN classification log must not fire for Graph-style access tokens; got: %q", output)
assert.True(t, authenticated, "session must remain authenticated via the ID token fallback")
assert.False(t, needsRefresh, "valid ID token must not signal a refresh need")
assert.False(t, expired, "valid ID token must not be reported as expired")
}
// TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection covers the
// classifier added by the followup fix. Pure-function unit test for the
// Microsoft proprietary marker we rely on (nonce in JWT header).
func TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-detection-kid"
standardToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
})
require.NoError(t, err)
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
oidc, _ := newAzureFollowupOIDC(t, &JWKSet{})
cases := []struct {
name string
token string
wantUnverified bool
}{
{name: "standard JWT without nonce header", token: standardToken, wantUnverified: false},
{name: "Microsoft proprietary token (nonce in header)", token: graphToken, wantUnverified: true},
{name: "garbage token treated as unverifiable", token: "not-a-jwt-at-all", wantUnverified: true},
{name: "empty token treated as unverifiable", token: "", wantUnverified: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := oidc.isUnverifiableAzureAccessToken(tc.token)
assert.Equal(t, tc.wantUnverified, got)
})
}
}
// TestIssue134_Followup_StandardAzureAccessTokenStillVerifies guards against
// regression in the happy path: an access token issued for our own clientID
// (custom Azure-registered API) — no proprietary nonce header, signed normally
// — must still flow through the standard verification path and authenticate.
func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-standard-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
now := time.Now()
exp := now.Add(time.Hour).Unix()
// Custom-resource access token: aud points to the app, no nonce header.
accessToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"scp": "api.read",
"jti": "standard-access-jti",
})
require.NoError(t, err)
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"email": "user@example.com",
"nonce": "id-token-oidc-nonce",
"jti": "standard-id-jti",
})
require.NoError(t, err)
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, accessToken, idToken)
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
assert.False(t, needsRefresh)
assert.False(t, expired)
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error",
"standard Azure token must not produce signature errors")
}
// TestIssue134_Followup_GraphAccessTokenWithoutIDToken covers the edge where
// the session has only a Graph access token (no ID token). The classifier must
// preserve the existing "treat as opaque" semantics for backward compatibility:
// authenticated=true even when there is no ID token to verify.
func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-no-idt-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-no-idt", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, "")
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
assert.False(t, needsRefresh)
assert.False(t, expired)
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error")
}
// TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification proves
// the classifier is not a security regression. An attacker who forges a JWT
// with a `nonce` JWT header (Microsoft's proprietary marker) but a payload
// claiming `aud=our-clientID` should NOT gain authenticated status simply by
// triggering the "treat as opaque" branch.
//
// This is the confused-deputy guardrail Microsoft warns about
// (https://cwe.mitre.org/data/definitions/441.html): we treat the access token
// as opaque, which means we DO NOT authorize from it — authorization comes
// only from a separately verifiable ID token. An attacker without a valid ID
// token must not be authenticated.
func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
attackerKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-attack-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
// Forged: attacker uses their OWN key, sets aud = our clientID, plants a
// `nonce` header to trip the opaque-detection path.
forgedAccessToken := signGraphStyleAccessToken(t, attackerKey, kid, "attacker-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "attacker",
"scp": "admin",
})
// Forged ID token signed with the attacker's key — must fail verification
// against the tenant JWKS.
forgedIDToken, err := createTestJWT(attackerKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Minute).Unix(),
"nbf": time.Now().Add(-2 * time.Minute).Unix(),
"sub": "attacker",
"email": "attacker@evil.example",
"nonce": "id-token-oidc-nonce",
"jti": "attacker-id-jti",
})
require.NoError(t, err)
oidc, _ := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
authenticated, _, _ := oidc.validateAzureTokens(session)
assert.False(t, authenticated,
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
}
+256
View File
@@ -0,0 +1,256 @@
package traefikoidc
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
// verifies the fix for issue #134.
//
// Symptom (before fix): with a Redis backend wired into UniversalCache,
// caching the parsed *parsedJWKS triggered:
//
// json: cannot unmarshal number 2251513...
// into Go value of type float64
//
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
// interface{} parses numbers as float64, which cannot represent that range.
// The user saw the error log on every request even though auth still worked
// (fallback path rebuilt the keys in memory).
//
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
// distributed backend never sees them.
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA",
Use: "sig",
Alg: "RS256",
Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
var fetchCount int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&fetchCount, 1)
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
errBuf := &bytes.Buffer{}
infoBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(infoBuf, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
ctx := context.Background()
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "first GetPublicKey should succeed")
require.NotNil(t, pub1)
gotRSA, ok := pub1.(*rsa.PublicKey)
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "second GetPublicKey should succeed")
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
errOutput := errBuf.String()
assert.NotContains(t, errOutput, "Failed to deserialize",
"deserialize error must not appear with the fix in place; got: %s", errOutput)
assert.NotContains(t, errOutput, "into Go value of type float64",
"float64 unmarshal error must not appear; got: %s", errOutput)
parsedKey := server.URL + parsedKeysSuffix
jwksKey := server.URL
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
fullKey := redisCfg.RedisPrefix + k
assert.False(t, mr.Exists(fullKey),
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
}
}
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
// data left in Redis under a JWK :parsed key from a prior buggy version is
// ignored: the local-only fix never reads that key, so no log spam, and the
// fallback path returns a real *rsa.PublicKey.
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134stale:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
// Pre-poison Redis with the kind of payload the old buggy path would have
// produced (huge unquoted JSON number for the modulus). With the fix the
// JWKCache must not even read this key.
poisoned := []byte("\x01" + strings.Replace(
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
"NUMBER", rsaKey.N.String(), 1,
))
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
errBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
require.NoError(t, err)
require.NotNil(t, pub)
gotRSA, ok := pub.(*rsa.PublicKey)
require.True(t, ok)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
}
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
// pair never reads or writes the configured backend.
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "local:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 10,
Logger: GetSingletonNoOpLogger(),
}, backend)
defer cache.Close()
type unsafeShape struct {
hidden map[string]interface{}
}
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
got, found := cache.GetLocal("local-key")
require.True(t, found)
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
for _, k := range mr.Keys() {
assert.NotContains(t, k, "local-key",
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
}
cache.mu.Lock()
delete(cache.items, "local-key")
cache.lruList.Init()
cache.currentSize = 0
cache.currentMemory = 0
cache.mu.Unlock()
_, found = cache.GetLocal("local-key")
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
}
// big2bytes returns the big-endian byte slice for a positive int.
func big2bytes(e int) []byte {
if e <= 0 {
return []byte{}
}
var buf []byte
for e > 0 {
buf = append([]byte{byte(e & 0xff)}, buf...)
e >>= 8
}
return buf
}
// samePublicKey reports whether two crypto.PublicKey instances represent the
// same RSA key, used to confirm cache hits return identical reconstructed
// keys.
func samePublicKey(a, b interface{}) bool {
ar, ok1 := a.(*rsa.PublicKey)
br, ok2 := b.(*rsa.PublicKey)
if !ok1 || !ok2 {
return false
}
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
}
+925
View File
@@ -0,0 +1,925 @@
package traefikoidc
// issue135_regression_test.go — regression tests for RFC 7523 private_key_jwt
// client authentication (issue #135).
//
// These tests guard:
// - Correct JWT construction and cryptographic signature for all supported
// algorithms (RS*/PS*/ES*).
// - Proper validation of alg/key type combinations and empty-kid rejection.
// - JTI uniqueness across concurrent calls.
// - PEM variant tolerance (PKCS#8, PKCS#1, SEC1).
// - Config.Validate() behavior for all private_key_jwt configuration paths.
// - buildClientAssertionSignerFromConfig: inline PEM, file-backed PEM, default alg.
// - Wire-up in exchangeTokens: assertion fields sent, client_secret absent.
// - Wire-up in RevokeTokenWithProvider: assertion fields sent, audience = tokenURL.
// - Back-compat: client_secret_post path unchanged when clientAssertion == nil.
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ── A. Signer unit tests ──────────────────────────────────────────────────────
// TestIssue135_SignerRSAFamily verifies that NewClientAssertionSigner + Sign
// produces a well-formed, cryptographically valid JWT for every RSA-family
// algorithm (RS256/RS384/RS512/PS256/PS384/PS512).
func TestIssue135_SignerRSAFamily(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
cases := []struct {
alg string
hashFn func([]byte) []byte
isPS bool
hash crypto.Hash
}{
{"RS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, false, crypto.SHA256},
{"RS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, false, crypto.SHA384},
{"RS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, false, crypto.SHA512},
{"PS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, true, crypto.SHA256},
{"PS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, true, crypto.SHA384},
{"PS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, true, crypto.SHA512},
}
const (
audience = "https://example.com/token"
clientID = "client-abc"
kid = "kid-1"
)
for _, tc := range cases {
t.Run(tc.alg, func(t *testing.T) {
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
require.NoError(t, err)
jwtStr, err := signer.Sign(audience, clientID)
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3, "JWT must have three dot-separated parts")
// Decode and check header.
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, tc.alg, hdr["alg"])
assert.Equal(t, "JWT", hdr["typ"])
assert.Equal(t, kid, hdr["kid"])
// Decode and check claims.
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, clientID, clms["iss"])
assert.Equal(t, clientID, clms["sub"])
assert.Equal(t, audience, clms["aud"])
iat, ok := clms["iat"].(float64)
require.True(t, ok, "iat must be numeric")
exp, ok := clms["exp"].(float64)
require.True(t, ok, "exp must be numeric")
assert.InDelta(t, 60, exp-iat, 2, "exp-iat must equal ~60s")
now := float64(time.Now().Unix())
assert.True(t, iat <= now+2 && iat >= now-5, "iat must be current time ±5s")
jti, ok := clms["jti"].(string)
require.True(t, ok, "jti must be a string")
assert.Len(t, jti, 32, "jti must be 32-char hex (16 bytes → hex)")
// Verify cryptographic signature.
sigInput := parts[0] + "." + parts[1]
digest := tc.hashFn([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
pub := &rsaKey.PublicKey
if tc.isPS {
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: tc.hash}
assert.NoError(t, rsa.VerifyPSS(pub, tc.hash, digest, sigBytes, opts),
"PSS signature verification failed for %s", tc.alg)
} else {
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tc.hash, digest, sigBytes),
"PKCS1v15 signature verification failed for %s", tc.alg)
}
})
}
}
// TestIssue135_SignerECDSAFamily verifies correct JWT production for all
// ECDSA algorithms (ES256/ES384/ES512) including that the signature is the
// raw r||s encoding (not ASN.1 DER) and is verifiable with the matching key.
func TestIssue135_SignerECDSAFamily(t *testing.T) {
cases := []struct {
alg string
curve elliptic.Curve
hashFn func([]byte) []byte
hash crypto.Hash
}{
{"ES256", elliptic.P256(), func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, crypto.SHA256},
{"ES384", elliptic.P384(), func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, crypto.SHA384},
{"ES512", elliptic.P521(), func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, crypto.SHA512},
}
const (
audience = "https://idp.example.com/token"
clientID = "ec-client"
kid = "ec-kid"
)
for _, tc := range cases {
t.Run(tc.alg, func(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
require.NoError(t, err)
pemBytes := encodeECPKCS8(t, ecKey)
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
require.NoError(t, err)
jwtStr, err := signer.Sign(audience, clientID)
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
byteLen := (tc.curve.Params().BitSize + 7) / 8
assert.Len(t, sigBytes, 2*byteLen,
"ECDSA signature must be raw r||s (2×%d bytes for %s)", byteLen, tc.alg)
r := new(big.Int).SetBytes(sigBytes[:byteLen])
s := new(big.Int).SetBytes(sigBytes[byteLen:])
sigInput := parts[0] + "." + parts[1]
digest := tc.hashFn([]byte(sigInput))
ok := ecdsa.Verify(&ecKey.PublicKey, digest, r, s)
assert.True(t, ok, "ECDSA signature verification failed for %s", tc.alg)
})
}
}
// TestIssue135_SignerRejectsAlgKeyMismatch verifies that the signer constructor
// rejects type mismatches between key type and algorithm, unknown algorithms,
// and an empty kid.
func TestIssue135_SignerRejectsAlgKeyMismatch(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
rsaPEM := encodeRSAPKCS8(t, rsaKey)
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
ecPEM := encodeECPKCS8(t, ecKey)
cases := []struct {
name string
pemBytes []byte
alg string
kid string
wantErr string
}{
{
name: "RSA key with ES256",
pemBytes: rsaPEM,
alg: "ES256",
kid: "k1",
wantErr: "EC key",
},
{
name: "EC key with RS256",
pemBytes: ecPEM,
alg: "RS256",
kid: "k1",
wantErr: "RSA key",
},
{
name: "unknown alg HS256",
pemBytes: rsaPEM,
alg: "HS256",
kid: "k1",
wantErr: "unsupported",
},
{
name: "empty kid",
pemBytes: rsaPEM,
alg: "RS256",
kid: "",
wantErr: "kid must not be empty",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewClientAssertionSigner(tc.pemBytes, tc.alg, tc.kid)
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.wantErr),
"error should mention %q", tc.wantErr)
})
}
}
// TestIssue135_SignerJTIUniqueness signs 50 assertions with the same signer
// and asserts all jti values are distinct. Guards against broken entropy reuse.
func TestIssue135_SignerJTIUniqueness(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "jti-kid")
require.NoError(t, err)
seen := make(map[string]bool, 50)
for i := range 50 {
jwtStr, err := signer.Sign("https://example.com/token", "client-x")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
clms := decodeJSONPart(t, parts[1])
jti, ok := clms["jti"].(string)
require.True(t, ok)
assert.False(t, seen[jti], "jti %q was reused at iteration %d", jti, i)
seen[jti] = true
}
}
// TestIssue135_SignerPEMVariants confirms that all PEM block types understood
// by NewClientAssertionSigner are parsed correctly: PKCS#8 ("PRIVATE KEY"),
// PKCS#1 ("RSA PRIVATE KEY"), and SEC1 ("EC PRIVATE KEY").
func TestIssue135_SignerPEMVariants(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
t.Run("RSA PKCS8", func(t *testing.T) {
pemBytes := encodeRSAPKCS8(t, rsaKey)
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
require.NoError(t, err)
assertValidRSAJWT(t, rsaKey, signer, "RS256")
})
t.Run("RSA PKCS1", func(t *testing.T) {
der := x509.MarshalPKCS1PrivateKey(rsaKey)
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
require.NoError(t, err)
assertValidRSAJWT(t, rsaKey, signer, "RS256")
})
t.Run("EC PKCS8", func(t *testing.T) {
pemBytes := encodeECPKCS8(t, ecKey)
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "cid")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
})
t.Run("EC SEC1", func(t *testing.T) {
der, err := x509.MarshalECPrivateKey(ecKey)
require.NoError(t, err)
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "cid")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
})
}
// ── B. Config validation ──────────────────────────────────────────────────────
// TestIssue135_ConfigValidation table-drives Config.Validate() for every
// client-authentication-related validation branch.
func TestIssue135_ConfigValidation(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
validPEM := string(encodeRSAPKCS8(t, rsaKey))
// baseConfig returns the minimum valid config, modified per test case.
base := func() *Config {
return &Config{
ProviderURL: "https://idp.example.com",
CallbackURL: "/cb",
ClientID: "cid",
ClientSecret: "secret",
SessionEncryptionKey: "01234567890123456789012345678901", // 32 chars
RateLimit: 100,
}
}
cases := []struct {
name string
mutate func(*Config)
wantErr string // empty = expect nil error
}{
{
name: "default empty method + secret ok",
mutate: func(c *Config) { /* nothing extra */ },
wantErr: "",
},
{
name: "explicit client_secret_post + secret ok",
mutate: func(c *Config) {
c.ClientAuthMethod = "client_secret_post"
},
wantErr: "",
},
{
name: "private_key_jwt inline key + kid ok",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyID = "k1"
},
wantErr: "",
},
{
name: "private_key_jwt no key at all",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionKeyID = "k1"
},
wantErr: "clientAssertionPrivateKey",
},
{
name: "private_key_jwt both inline and path",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyPath = "/tmp/key.pem"
c.ClientAssertionKeyID = "k1"
},
wantErr: "only one of",
},
{
name: "private_key_jwt key but no kid",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
},
wantErr: "clientAssertionKeyID",
},
{
name: "private_key_jwt unsupported alg HS256",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyID = "k1"
c.ClientAssertionAlg = "HS256"
},
wantErr: "is not supported",
},
{
name: "unknown client auth method",
mutate: func(c *Config) {
c.ClientAuthMethod = "weird"
},
wantErr: "is not supported",
},
{
name: "client_secret_post with no secret",
mutate: func(c *Config) {
c.ClientAuthMethod = "client_secret_post"
c.ClientSecret = ""
},
wantErr: "clientSecret is required",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cfg := base()
tc.mutate(cfg)
err := cfg.Validate()
if tc.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErr,
"error must mention %q", tc.wantErr)
}
})
}
}
// TestIssue135_ConfigKeyPathLoadsFile verifies that buildClientAssertionSignerFromConfig
// reads the PEM key from disk when ClientAssertionKeyPath is set.
func TestIssue135_ConfigKeyPathLoadsFile(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
dir := t.TempDir()
keyFile := dir + "/private.pem"
require.NoError(t, os.WriteFile(keyFile, pemBytes, 0o600))
cfg := &Config{
ClientAuthMethod: "private_key_jwt",
ClientAssertionKeyPath: keyFile,
ClientAssertionKeyID: "file-kid",
ClientAssertionAlg: "RS256",
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err, "should load signer from key file")
require.NotNil(t, signer)
// Confirm signer produces a valid JWT.
jwtStr, err := signer.Sign("https://example.com/token", "client-from-file")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3, "should produce a 3-part JWT")
}
// ── C. Wire-up — exchangeTokens ───────────────────────────────────────────────
// TestIssue135_AuthCodeExchangeUsesAssertion confirms that exchangeTokens sends
// client_assertion + client_assertion_type instead of client_secret when a
// ClientAssertionSigner is configured, and that the assertion JWT is valid.
func TestIssue135_AuthCodeExchangeUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
var capturedBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := make([]byte, r.ContentLength)
_, _ = r.Body.Read(body)
capturedBody = body
w.Header().Set("Content-Type", "application/json")
// Return a minimal token response so exchangeTokens doesn't error.
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at",
IDToken: "it",
RefreshToken: "rt",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "wire-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: "wire-client",
tokenHTTPClient: server.Client(),
clientAssertion: signer,
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err = oidc.exchangeTokens(context.Background(), "authorization_code", "code-x", "https://app/cb", "")
require.NoError(t, err)
form, err := url.ParseQuery(string(capturedBody))
require.NoError(t, err)
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
form.Get("client_assertion_type"), "client_assertion_type must be set")
assertionJWT := form.Get("client_assertion")
assert.NotEmpty(t, assertionJWT, "client_assertion must be present")
assert.Empty(t, form.Get("client_secret"), "client_secret must not be sent when using assertion")
assert.Equal(t, "wire-client", form.Get("client_id"))
assert.Equal(t, "code-x", form.Get("code"))
assert.Equal(t, "authorization_code", form.Get("grant_type"))
// Verify assertion JWT: header, claims, signature.
parts := strings.Split(assertionJWT, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "RS256", hdr["alg"])
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, "wire-client", clms["iss"])
assert.Equal(t, "wire-client", clms["sub"])
assert.Equal(t, server.URL, clms["aud"],
"audience must be the tokenURL (RFC 7523 §3)")
// Verify signature with RSA public key.
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
}
// TestIssue135_RefreshTokenUsesAssertion verifies that the refresh_token grant
// type also sends client_assertion and the correct form fields.
func TestIssue135_RefreshTokenUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-at",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rt-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: "rt-client",
tokenHTTPClient: server.Client(),
clientAssertion: signer,
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err = oidc.exchangeTokens(context.Background(), "refresh_token", "rt-y", "", "")
require.NoError(t, err)
assert.Equal(t, "refresh_token", capturedForm.Get("grant_type"))
assert.Equal(t, "rt-y", capturedForm.Get("refresh_token"))
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
capturedForm.Get("client_assertion_type"))
assert.NotEmpty(t, capturedForm.Get("client_assertion"))
assert.Empty(t, capturedForm.Get("client_secret"))
}
// TestIssue135_BackcompatClientSecretPath confirms that exchangeTokens sends
// client_secret and does NOT send client_assertion when clientAssertion is nil.
func TestIssue135_BackcompatClientSecretPath(t *testing.T) {
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "legacy-client",
clientSecret: "legacy-secret",
tokenHTTPClient: server.Client(),
clientAssertion: nil, // back-compat path
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bc", "https://app/cb", "")
require.NoError(t, err)
assert.Equal(t, "legacy-secret", capturedForm.Get("client_secret"),
"client_secret must be sent on the classic path")
assert.Empty(t, capturedForm.Get("client_assertion"),
"client_assertion must NOT be present on the classic path")
assert.Empty(t, capturedForm.Get("client_assertion_type"),
"client_assertion_type must NOT be present on the classic path")
}
// TestIssue135_ClientSecretBasicAuth verifies that when clientAuthMethod is
// "client_secret_basic", exchangeTokens sends an HTTP Basic Authorization
// header carrying url-encoded client_id:client_secret per RFC 6749 §2.3.1,
// and that neither client_id nor client_secret appears in the form body.
func TestIssue135_ClientSecretBasicAuth(t *testing.T) {
var capturedAuth string
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at-basic", TokenType: "Bearer", ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "basic-client",
clientSecret: "basic-secret",
clientAuthMethod: "client_secret_basic",
tokenHTTPClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bb", "https://app/cb", "")
require.NoError(t, err)
require.True(t, strings.HasPrefix(capturedAuth, "Basic "),
"Authorization header must start with 'Basic ', got %q", capturedAuth)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err, "Authorization payload must be valid base64")
user, pass, ok := strings.Cut(string(raw), ":")
require.True(t, ok, "Authorization payload must contain a single ':' separator")
assert.Equal(t, "basic-client", user, "client_id should round-trip through QueryEscape")
assert.Equal(t, "basic-secret", pass, "client_secret should round-trip through QueryEscape")
assert.Empty(t, capturedForm.Get("client_id"),
"client_id must NOT be in the body when using client_secret_basic")
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must NOT be in the body when using client_secret_basic")
assert.Empty(t, capturedForm.Get("client_assertion"),
"client_assertion must NOT be present on the basic-auth path")
}
// TestIssue135_ClientSecretBasicURLEncodesReservedChars verifies that
// credentials containing reserved characters (`:`, `+`, `/`, etc.) are
// form-urlencoded before base64 per RFC 6749 §2.3.1, so the receiving
// authorization server can decode them deterministically.
func TestIssue135_ClientSecretBasicURLEncodesReservedChars(t *testing.T) {
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{AccessToken: "at", TokenType: "Bearer", ExpiresIn: 3600})
}))
defer server.Close()
const (
clientID = "weird:id+1"
clientSecret = "p@ss/word=&" //nolint:gosec // test fixture
)
oidc := &TraefikOidc{
clientID: clientID,
clientSecret: clientSecret,
clientAuthMethod: "client_secret_basic",
tokenHTTPClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "c", "https://app/cb", "")
require.NoError(t, err)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err)
wantUser := url.QueryEscape(clientID)
wantPass := url.QueryEscape(clientSecret)
assert.Equal(t, wantUser+":"+wantPass, string(raw),
"both halves must be form-urlencoded before the base64 step")
}
// TestIssue135_ClientSecretBasicRevocation verifies that the revocation path
// honors client_secret_basic identically to the token path.
func TestIssue135_ClientSecretBasicRevocation(t *testing.T) {
var capturedAuth string
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "rev-basic",
clientSecret: "rev-secret",
clientAuthMethod: "client_secret_basic",
httpClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = "https://idp.example.com/token"
oidc.revocationURL = server.URL
require.NoError(t, oidc.RevokeTokenWithProvider("opaque-tok", "access_token"))
require.True(t, strings.HasPrefix(capturedAuth, "Basic "), "got %q", capturedAuth)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err)
assert.Equal(t, "rev-basic:rev-secret", string(raw))
assert.Equal(t, "opaque-tok", capturedForm.Get("token"))
assert.Equal(t, "access_token", capturedForm.Get("token_type_hint"))
assert.Empty(t, capturedForm.Get("client_id"),
"client_id must NOT be in body on Basic-auth revocation")
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must NOT be in body on Basic-auth revocation")
}
// ── D. Wire-up — RevokeTokenWithProvider ────────────────────────────────────
// TestIssue135_RevocationUsesAssertion verifies that RevokeTokenWithProvider
// sends client_assertion (not client_secret), and that the assertion's audience
// is the tokenURL, not the revocationURL (per RFC 7523 §3).
func TestIssue135_RevocationUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
const (
tokenEndpoint = "https://idp.example.com/token" // audience for assertion
clientIDVal = "revoke-client"
)
var capturedForm url.Values
// Revocation endpoint — deliberate separate URL to confirm audience != revocationURL.
revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.WriteHeader(http.StatusOK)
}))
defer revokeServer.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rev-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: clientIDVal,
clientAssertion: signer,
httpClient: revokeServer.Client(),
logger: GetSingletonNoOpLogger(),
}
// tokenURL drives assertion audience; revocationURL is where the POST goes.
oidc.tokenURL = tokenEndpoint
oidc.revocationURL = revokeServer.URL
err = oidc.RevokeTokenWithProvider("some-token", "refresh_token")
require.NoError(t, err)
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
capturedForm.Get("client_assertion_type"))
assertionJWT := capturedForm.Get("client_assertion")
assert.NotEmpty(t, assertionJWT)
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must not appear in revocation request with assertion")
// Verify the assertion audience is tokenURL (not revocationURL).
parts := strings.Split(assertionJWT, ".")
require.Len(t, parts, 3)
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, tokenEndpoint, clms["aud"],
"assertion audience must be tokenURL, not revocationURL")
// Sanity-check cryptographic validity.
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
}
// ── E. End-to-end via buildClientAssertionSignerFromConfig ───────────────────
// TestIssue135_BuildSignerFromInlineConfig confirms that the full config→signer
// pipeline works for an ES256 key specified inline in the Config struct.
func TestIssue135_BuildSignerFromInlineConfig(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pemBytes := encodeECPKCS8(t, ecKey)
cfg := &Config{
ClientAuthMethod: "private_key_jwt",
ClientAssertionPrivateKey: string(pemBytes),
ClientAssertionKeyID: "inline-ec-kid",
ClientAssertionAlg: "ES256",
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err)
require.NotNil(t, signer)
jwtStr, err := signer.Sign("https://example.com/token", "inline-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "ES256", hdr["alg"])
assert.Equal(t, "inline-ec-kid", hdr["kid"])
// Verify the EC signature.
byteLen := (elliptic.P256().Params().BitSize + 7) / 8
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
require.Len(t, sigBytes, 2*byteLen)
r := new(big.Int).SetBytes(sigBytes[:byteLen])
s := new(big.Int).SetBytes(sigBytes[byteLen:])
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
assert.True(t, ecdsa.Verify(&ecKey.PublicKey, digest, r, s))
}
// TestIssue135_BuildSignerDefaultsToRS256 verifies that an empty
// ClientAssertionAlg defaults to RS256.
func TestIssue135_BuildSignerDefaultsToRS256(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
cfg := &Config{
ClientAssertionPrivateKey: string(pemBytes),
ClientAssertionKeyID: "default-alg-kid",
ClientAssertionAlg: "", // intentionally empty
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "default-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "RS256", hdr["alg"], "empty alg must default to RS256")
}
// ── Helpers ───────────────────────────────────────────────────────────────────
// genRSAKey generates an RSA key of the given bit size, failing the test on error.
func genRSAKey(t *testing.T, bits int) *rsa.PrivateKey {
t.Helper()
k, err := rsa.GenerateKey(rand.Reader, bits)
require.NoError(t, err)
return k
}
// encodeRSAPKCS8 marshals an RSA key as PKCS#8 PEM ("PRIVATE KEY").
func encodeRSAPKCS8(t *testing.T, key *rsa.PrivateKey) []byte {
t.Helper()
der, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
}
// encodeECPKCS8 marshals an EC key as PKCS#8 PEM ("PRIVATE KEY").
func encodeECPKCS8(t *testing.T, key *ecdsa.PrivateKey) []byte {
t.Helper()
der, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
}
// decodeJSONPart base64url-decodes a JWT part and parses it as a JSON object.
func decodeJSONPart(t *testing.T, b64url string) map[string]any {
t.Helper()
raw, err := base64.RawURLEncoding.DecodeString(b64url)
require.NoError(t, err, "base64url decode of JWT part failed")
var m map[string]any
require.NoError(t, json.Unmarshal(raw, &m), "JSON unmarshal of JWT part failed")
return m
}
// sha256SumBytes returns the SHA-256 digest of b as a byte slice.
func sha256SumBytes(b []byte) []byte {
h := sha256.Sum256(b)
return h[:]
}
// assertValidRSAJWT signs a JWT with signer and verifies the RS256 signature
// against the given RSA public key. Used by PEM variant tests.
func assertValidRSAJWT(t *testing.T, key *rsa.PrivateKey, signer *ClientAssertionSigner, alg string) {
t.Helper()
jwtStr, err := signer.Sign("https://example.com/token", "pem-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, alg, hdr["alg"])
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
assert.NoError(t, rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, digest, sigBytes))
}
+19 -5
View File
@@ -76,9 +76,15 @@ func NewJWKCache() *JWKCache {
}
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
//
// The entry is stored locally only via SetLocal/GetLocal. Going through a
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
// into map[string]interface{}, the type assertion below fails, and every
// request refetches from the upstream. JWK rotation is rare and a per-replica
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Check cache first
if cachedValue, found := c.cache.Get(jwksURL); found {
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
@@ -88,7 +94,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
defer c.mutex.Unlock()
// Double-check after acquiring lock
if cachedValue, found := c.cache.Get(jwksURL); found {
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
@@ -105,7 +111,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
}
// Cache for 1 hour
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
return jwks, nil
}
@@ -114,9 +120,17 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
// stored alongside the raw JWKSet under a sibling cache key with the same
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
//
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
// which contain *big.Int that marshals to a hundreds-digit JSON number.
// On a distributed backend round-trip, json.Unmarshal into interface{} would
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
// Marshal, leaking the modulus into the cached payload (issue #134).
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
parsedKey := jwksURL + parsedKeysSuffix
if v, found := c.cache.Get(parsedKey); found {
if v, found := c.cache.GetLocal(parsedKey); found {
if pj, ok := v.(*parsedJWKS); ok {
if k, ok := pj.keys[kid]; ok {
return k, nil
@@ -130,7 +144,7 @@ func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpCl
}
pj := buildParsedJWKS(jwks)
_ = c.cache.Set(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
if k, ok := pj.keys[kid]; ok {
return k, nil
+28
View File
@@ -169,6 +169,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
clientID: config.ClientID,
clientSecret: config.ClientSecret,
clientAuthMethod: func() string {
if config.ClientAuthMethod != "" {
return config.ClientAuthMethod
}
return "client_secret_post"
}(),
audience: func() string {
if config.Audience != "" {
return config.Audience
@@ -226,6 +232,13 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
maxRefreshTokenAge: func() time.Duration {
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
if config.MaxRefreshTokenAgeSeconds > 0 {
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
}
return 0
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
@@ -242,6 +255,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
}
// Log audience configuration
@@ -260,6 +274,20 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
tokenResilienceConfig := DefaultTokenResilienceConfig()
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
// call, preventing the thundering herd that yields invalid_grant when the IdP
// rotates refresh tokens (Zitadel/Authentik default).
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
if config.ClientAuthMethod == "private_key_jwt" {
signer, err := buildClientAssertionSignerFromConfig(config)
if err != nil {
cancelFunc()
return nil, fmt.Errorf("failed to build client assertion signer: %w", err)
}
t.clientAssertion = signer
}
t.extractClaimsFunc = extractClaims
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
+199 -47
View File
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
}
}
// TestServeHTTP_EventStream tests the event-stream bypass functionality
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
// handshake must skip the OIDC redirect dance (clients can't follow it
// mid-stream) but it must STILL require an authenticated session, otherwise
// any caller could reach the backend by setting Accept: text/event-stream.
func TestServeHTTP_EventStream(t *testing.T) {
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
}
})
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
// Build an authenticated session and inject its cookies onto req.
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated SSE request to be forwarded to backend")
}
if forwardedUser != "user@example.com" {
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
}
})
}
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
// handshake bypasses the OIDC redirect (clients can't follow it) but the
// session must already be authenticated, otherwise the backend is exposed
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
oidc.ServeHTTP(rw, req)
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rw := httptest.NewRecorder()
if !nextCalled {
t.Error("expected event-stream request to bypass OIDC")
}
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
}
})
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
}))
req := httptest.NewRequest("GET", "/ws", nil)
// Mixed-case + multi-token Connection header to exercise parsing.
req.Header.Set("Connection", "keep-alive, Upgrade")
req.Header.Set("Upgrade", "WebSocket")
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("ws-user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
}
if forwardedUser != "ws-user@example.com" {
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
}
})
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
// Sanity: requests without Upgrade headers must NOT hit the WS
// bypass branch (otherwise the new code path could short-circuit
// normal authentication).
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("backend must not be called for unauthenticated plain HTTP")
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "keep-alive")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code == http.StatusOK {
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
}
})
}
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "",
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetEmail(session.email)
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
// MockSessionData is a test implementation of SessionData interface
type MockSessionData struct {
email string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
userIdentifier string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetEmail() string { return m.email }
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string { return m.idToken }
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
func (m *MockSessionData) SetEmail(email string) { m.email = email }
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
}
// Set up session data
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -771,7 +923,7 @@ func TestStripAuthCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
@@ -852,7 +1004,7 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -899,7 +1051,7 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
@@ -950,7 +1102,7 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
@@ -1013,7 +1165,7 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
+15 -15
View File
@@ -580,7 +580,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -603,7 +603,7 @@ func TestServeHTTP(t *testing.T) {
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -660,7 +660,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -678,7 +678,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -706,7 +706,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -741,7 +741,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
@@ -772,7 +772,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
@@ -792,7 +792,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -814,7 +814,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -2179,7 +2179,7 @@ func TestHandleExpiredToken(t *testing.T) {
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
},
expectedPath: "/original/path",
},
@@ -2756,7 +2756,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2782,7 +2782,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2809,7 +2809,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
@@ -2829,7 +2829,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2851,7 +2851,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
+150 -78
View File
@@ -14,21 +14,40 @@ import (
)
// bypassReason describes why a request is being forwarded without OIDC auth.
// It is only used for logging and to decide whether extra SSE-specific work
// It is only used for logging and to decide whether extra side-effects
// (propagating the user header from an existing session) should run.
const (
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonWebSocket = "websocket"
)
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
// (RFC 6455). The middleware can only see the handshake; once Traefik
// completes the upgrade it forwards frames directly, so we never re-process
// per-frame traffic. We bypass auth on the handshake the same way we do for
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
func isWebSocketUpgrade(req *http.Request) bool {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
return false
}
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
return true
}
}
return false
}
// shouldBypassAuth decides whether a request must skip OIDC authentication
// entirely. It returns (true, reason) when either the request path matches a
// configured excluded URL or the Accept header asks for a text/event-stream
// response (SSE). The reason lets ServeHTTP apply any side-effects that are
// unique to the bypass kind (e.g. propagating user headers for SSE).
// configured excluded URL, the Accept header asks for a text/event-stream
// response (SSE), or the request is a WebSocket upgrade handshake. The
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
// kind (e.g. propagating user headers).
//
// This must be called BEFORE waiting on t.initComplete so excluded and SSE
// traffic is never blocked by a slow/broken provider.
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
// WebSocket traffic is never blocked by a slow/broken provider.
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if t.determineExcludedURL(req.URL.Path) {
return true, bypassReasonExcluded
@@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true, bypassReasonSSE
}
if isWebSocketUpgrade(req) {
return true, bypassReasonWebSocket
}
return false, ""
}
// applySSEUserHeaders attempts to copy the authenticated user's identity from
// an existing session onto the outgoing SSE request so downstream services
// can still see who the user is. Failures are logged (not silenced) because
// they indicate either a corrupt cookie or a misconfigured session manager
// and are useful for debugging, but they never block the bypass itself.
func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) {
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
// requests and, on success, copies the authenticated user's identity onto
// the outgoing request so downstream services can see who the user is.
//
// Returns true when the request carries a valid authenticated session and
// the bypass should proceed. Returns false when no usable session is
// present; callers must then reject the request (typically with 401) to
// prevent unauthenticated traffic from reaching the backend just by setting
// `Accept: text/event-stream` or sending a WebSocket upgrade.
//
// The check is cookie-only: the session cookie is sealed by our encryption
// key, so the authenticated flag cannot be forged. We do NOT run full token
// signature verification here so that SSE/WS keeps working when the OIDC
// provider is briefly unavailable for JWK fetches.
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
if t.sessionManager == nil {
return
return false
}
session, err := t.sessionManager.GetSession(req)
if err != nil {
// Intentionally not fatal: SSE requests bypass auth, we just lose the
// forwarded-user header for this request.
t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err)
return
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
return false
}
defer session.returnToPoolSafely()
email := session.GetEmail()
if email == "" {
return
if !session.GetAuthenticated() {
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
return false
}
req.Header.Set("X-Forwarded-User", email)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", email)
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
return false
}
t.logger.Debugf("SSE bypass: forwarded user %s from session", email)
req.Header.Set("X-Forwarded-User", userIdentifier)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", userIdentifier)
}
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
return true
}
// ServeHTTP implements the main middleware logic for processing HTTP requests.
@@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.firstRequestMutex.Unlock()
}
// Evaluate auth-bypass once, before waiting for initialization. Excluded URLs
// and SSE requests must not block on provider init. For SSE we additionally
// attempt to forward the user identity from an existing session (best
// effort) so downstream handlers still see X-Forwarded-User.
// Evaluate auth-bypass once, before waiting for initialization. Excluded
// URLs, SSE and WebSocket upgrade requests must not block on provider
// init. For SSE/WebSocket we ALSO require an authenticated session
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
// of in-flight streams can't follow an OIDC redirect, so forwarding
// unauthenticated traffic would silently expose the backend.
if bypass, reason := t.shouldBypassAuth(req); bypass {
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
if reason == bypassReasonSSE {
t.applySSEUserHeaders(req)
switch reason {
case bypassReasonExcluded:
// Operator-declared excluded URLs forward unconditionally.
t.next.ServeHTTP(rw, req)
case bypassReasonSSE, bypassReasonWebSocket:
// Skip the OIDC redirect dance (clients can't follow it
// mid-stream) but still require an authenticated session.
// Otherwise an unauthenticated client could hit the backend
// just by setting Accept: text/event-stream or sending a
// WebSocket upgrade.
if !t.applyBypassUserHeaders(req, reason) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
t.next.ServeHTTP(rw, req)
default:
t.next.ServeHTTP(rw, req)
}
t.next.ServeHTTP(rw, req)
return
}
@@ -237,7 +289,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
userIdentifier := session.GetUserIdentifier()
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
@@ -309,7 +361,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
userIdentifier = session.GetUserIdentifier()
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
@@ -359,9 +411,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Info("No email found in session during final processing, initiating re-auth")
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
// Reset redirect count to prevent loops when session is invalid
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
@@ -374,7 +426,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", email)
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
// Clear the session and redirect to login
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
@@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
// Reset redirect count to prevent loops when token is missing
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token, so concurrent dashboard
// panel requests on the same session don't repeatedly base64-decode and
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
// that hosts Traefik plugins). idClaims is reused below by the
// header-templates branch.
idToken := session.GetIDToken()
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
// Choose which claims drive groups/roles extraction. Prefer the ID
// token (cached) and fall back to the access token if there is no ID
// token in the session — matching the prior behavior for opaque
// ID-token providers.
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if accessToken := session.GetAccessToken(); accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
var groups, roles []string
if groupClaimsErr == nil && groupClaims != nil {
var err error
groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
// Initialize empty slices
var groups, roles []string
if tokenForClaims != "" {
var err error
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Reset redirect count to prevent loops when claim extraction fails
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
} else if err == nil {
if err == nil {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -429,54 +502,53 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
t.logger.Infof("User %s does not have any allowed roles or groups", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
req.Header.Set("X-Forwarded-User", userIdentifier)
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-User", userIdentifier)
if idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
}
}
if len(t.headerTemplates) > 0 {
// Reuse claims parsed earlier in this request if the ID token has not
// changed. Saves an unnecessary JWT parse on every authenticated
// request that uses headerTemplates.
claims, err := session.GetIDTokenClaims(t.extractClaimsFunc)
if err != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
if idClaimsErr != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr)
} else {
// idClaims may be nil when no ID token is present; templates
// referencing .Claims.* will simply produce empty values, which
// matches the prior behavior.
templateData := map[string]interface{}{
"AccessToken": session.GetAccessToken(),
"IDToken": session.GetIDToken(),
"IDToken": idToken,
"RefreshToken": session.GetRefreshToken(),
"Claims": claims,
"Claims": idClaims,
}
for headerName, tmpl := range t.headerTemplates {
var buf bytes.Buffer
if err := tmpl.Execute(&buf, templateData); err != nil {
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
continue
}
headerValue := buf.String()
req.Header.Set(headerName, headerValue)
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
}
session.MarkDirty()
t.logger.Debugf("Session marked dirty after templated header processing.")
// NOTE: templates only mutate request headers (not session state),
// so we deliberately do NOT MarkDirty / Save here. Previously every
// authenticated request with header templates re-encrypted and
// rewrote all session cookies, which was a measurable CPU and
// Set-Cookie tax on dashboards that poll many panels per second.
}
}
@@ -515,7 +587,7 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", userIdentifier)
t.next.ServeHTTP(rw, req)
}
+7 -7
View File
@@ -161,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create authenticated session
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
session.SetIDToken("dummy-token")
session.Save(req, httptest.NewRecorder())
@@ -203,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create session with forbidden domain
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@forbidden.com")
session.SetUserIdentifier("user@forbidden.com")
session.SetAuthenticated(true)
// Save and inject cookies
@@ -252,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
// Create session with opaque token
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
session.SetAuthenticated(true)
@@ -291,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("") // No email
session.SetUserIdentifier("") // No email
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -321,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("") // No ID token
session.SetAccessToken("") // No access token
@@ -349,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -383,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
testEmail := "user@example.com"
session.SetEmail(testEmail)
session.SetUserIdentifier(testEmail)
session.SetIDToken("dummy-id-token")
rw := httptest.NewRecorder()
+13
View File
@@ -466,10 +466,23 @@ func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
// hashRefreshToken creates a hash of the refresh token for deduplication
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
return refreshCoordinatorSessionID(token)
}
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
// for both deduplication and per-session attempt tracking. Using sha256 of the
// raw token means each rotation produces a fresh sessionID with its own attempt
// budget, which is what we want.
func refreshCoordinatorSessionID(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
// coordinated refresh result. It is wider than RefreshTimeout so a follower
// always sees the leader's result instead of timing out independently.
const refreshCoordinatorWaitTimeout = 35 * time.Second
// isUnderMemoryPressure checks if the system is under memory pressure by
// consulting the global memory monitor. Returns true when pressure reaches
// High or Critical, at which point we refuse new refresh operations to
+164
View File
@@ -0,0 +1,164 @@
package traefikoidc
import (
"context"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// stubTokenExchanger lets us count how many upstream refresh-token grants
// happen for a given refresh_token across concurrent middleware-level calls.
type stubTokenExchanger struct {
calls int32
delay time.Duration
resp *TokenResponse
}
func (s *stubTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
return nil, nil
}
func (s *stubTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
atomic.AddInt32(&s.calls, 1)
if s.delay > 0 {
time.Sleep(s.delay)
}
return s.resp, nil
}
func (s *stubTokenExchanger) RevokeTokenWithProvider(_, _ string) error {
return nil
}
// TestCoordinatedTokenRefresh_SingleUpstreamCall verifies the wireup: many
// concurrent calls to coordinatedTokenRefresh with the same refresh token
// must collapse to a single tokenExchanger.GetNewTokenWithRefreshToken call.
//
// Without the wireup this assertion fails (one upstream call per goroutine).
func TestCoordinatedTokenRefresh_SingleUpstreamCall(t *testing.T) {
stub := &stubTokenExchanger{
delay: 100 * time.Millisecond,
resp: &TokenResponse{
AccessToken: "new_access",
RefreshToken: "new_refresh",
IDToken: "new_id",
ExpiresIn: 3600,
},
}
logger := NewLogger("error")
cfg := DefaultRefreshCoordinatorConfig()
cfg.MaxRefreshAttempts = 10000
cfg.MaxConcurrentRefreshes = 32
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
}
defer oidc.refreshCoordinator.Shutdown()
const concurrency = 50
var wg sync.WaitGroup
wg.Add(concurrency)
req := httptest.NewRequest("GET", "/", nil)
start := make(chan struct{})
for i := 0; i < concurrency; i++ {
go func() {
defer wg.Done()
<-start
resp, err := oidc.coordinatedTokenRefresh(req, "shared_refresh_token")
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if resp == nil || resp.AccessToken != "new_access" {
t.Errorf("unexpected response: %+v", resp)
}
}()
}
close(start)
wg.Wait()
got := atomic.LoadInt32(&stub.calls)
// Up to 2 is acceptable to absorb the documented timing slack in the
// existing coordinator tests (e.g. operation just cleaned up before a
// late goroutine reads the in-flight map). Anything beyond that means
// coalescing is broken.
if got > 2 {
t.Fatalf("expected <=2 upstream refresh calls, got %d", got)
}
}
// TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator verifies the nil
// coordinator path so existing tests that build TraefikOidc literals stay
// green.
func TestCoordinatedTokenRefresh_FallsBackWithoutCoordinator(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "ok"},
}
oidc := &TraefikOidc{
logger: NewLogger("error"),
tokenExchanger: stub,
// refreshCoordinator deliberately nil
}
resp, err := oidc.coordinatedTokenRefresh(nil, "rt")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "ok" {
t.Fatalf("unexpected response: %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 1 {
t.Fatalf("expected exactly 1 upstream call, got %d", got)
}
}
// TestCoordinatedTokenRefresh_DistinctTokensRunInParallel verifies that
// distinct refresh tokens are not falsely coalesced.
func TestCoordinatedTokenRefresh_DistinctTokensRunInParallel(t *testing.T) {
stub := &stubTokenExchanger{
delay: 20 * time.Millisecond,
resp: &TokenResponse{AccessToken: "ok"},
}
logger := NewLogger("error")
cfg := DefaultRefreshCoordinatorConfig()
cfg.MaxRefreshAttempts = 10000
cfg.MaxConcurrentRefreshes = 32
cfg.DeduplicationCleanupDelay = 0
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(cfg, logger),
}
defer oidc.refreshCoordinator.Shutdown()
const distinct = 8
var wg sync.WaitGroup
wg.Add(distinct)
for i := 0; i < distinct; i++ {
i := i
go func() {
defer wg.Done()
_, err := oidc.coordinatedTokenRefresh(nil, refreshCoordinatorSessionID(string(rune('a'+i))))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}()
}
wg.Wait()
if got := atomic.LoadInt32(&stub.calls); int(got) != distinct {
t.Fatalf("expected %d distinct upstream calls, got %d", distinct, got)
}
}
+186
View File
@@ -0,0 +1,186 @@
package traefikoidc
import (
"context"
"errors"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// inMemoryCache is the smallest CacheInterface that satisfies the cross-
// replica dedup contract: Set/Get with TTL. Used in place of the universal
// cache singleton so these tests stay hermetic.
type inMemoryCache struct {
entries map[string]inMemoryCacheEntry
mu sync.Mutex
}
type inMemoryCacheEntry struct {
expiresAt time.Time
value interface{}
}
func newInMemoryCache() *inMemoryCache {
return &inMemoryCache{entries: make(map[string]inMemoryCacheEntry)}
}
func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.entries[key] = inMemoryCacheEntry{value: value, expiresAt: time.Now().Add(ttl)}
}
func (c *inMemoryCache) Get(key string) (any, bool) {
c.mu.Lock()
defer c.mu.Unlock()
e, ok := c.entries[key]
if !ok {
return nil, false
}
if time.Now().After(e.expiresAt) {
delete(c.entries, key)
return nil, false
}
return e.value, true
}
func (c *inMemoryCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.entries, key)
}
func (c *inMemoryCache) SetMaxSize(int) {}
func (c *inMemoryCache) Cleanup() {}
func (c *inMemoryCache) Close() {}
func (c *inMemoryCache) Size() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.entries)
}
func (c *inMemoryCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.entries = map[string]inMemoryCacheEntry{}
}
func (c *inMemoryCache) GetStats() map[string]any { return map[string]any{} }
// erroringTokenExchanger always errors - simulates an IdP rejection.
type erroringTokenExchanger struct {
calls int32
}
func (e *erroringTokenExchanger) ExchangeCodeForToken(_ context.Context, _, _, _, _ string) (*TokenResponse, error) {
return nil, errors.New("not used")
}
func (e *erroringTokenExchanger) GetNewTokenWithRefreshToken(_ string) (*TokenResponse, error) {
atomic.AddInt32(&e.calls, 1)
return nil, errors.New("invalid_grant")
}
func (e *erroringTokenExchanger) RevokeTokenWithProvider(_, _ string) error { return nil }
// TestCoordinatedTokenRefresh_CrossReplicaCacheHit simulates a peer Traefik
// replica having just refreshed: the shared cache already has the result, so
// this pod must reuse it without ever calling the IdP.
func TestCoordinatedTokenRefresh_CrossReplicaCacheHit(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "should_not_be_called"},
}
logger := NewLogger("error")
cache := newInMemoryCache()
preExisting := &TokenResponse{
AccessToken: "from_peer",
RefreshToken: "rotated_by_peer",
IDToken: "id_from_peer",
}
rt := "shared_refresh_token"
cache.Set(refreshResultCacheKey(refreshCoordinatorSessionID(rt)), preExisting, refreshResultCacheTTL)
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
resp, err := oidc.coordinatedTokenRefresh(httptest.NewRequest("GET", "/", nil), rt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "from_peer" {
t.Fatalf("expected peer-provided response, got %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 0 {
t.Fatalf("expected 0 upstream calls (peer already refreshed), got %d", got)
}
}
// TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache verifies that on a
// cache miss the leader stores its result for peers to find within the TTL.
func TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache(t *testing.T) {
stub := &stubTokenExchanger{
resp: &TokenResponse{AccessToken: "fresh_grant"},
}
logger := NewLogger("error")
cache := newInMemoryCache()
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: stub,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
rt := "fresh_refresh_token"
resp, err := oidc.coordinatedTokenRefresh(nil, rt)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.AccessToken != "fresh_grant" {
t.Fatalf("unexpected response: %+v", resp)
}
if got := atomic.LoadInt32(&stub.calls); got != 1 {
t.Fatalf("expected 1 upstream call, got %d", got)
}
v, ok := cache.Get(refreshResultCacheKey(refreshCoordinatorSessionID(rt)))
if !ok {
t.Fatal("expected refresh result to be cached after upstream success")
}
if tr, ok := v.(*TokenResponse); !ok || tr.AccessToken != "fresh_grant" {
t.Fatalf("cached value malformed: %+v", v)
}
}
// TestCoordinatedTokenRefresh_ErrorIsNotCached makes sure we don't poison the
// dedup cache when the IdP rejects the grant. Peers must run their own
// refresh; they cannot inherit an error.
func TestCoordinatedTokenRefresh_ErrorIsNotCached(t *testing.T) {
failing := &erroringTokenExchanger{}
logger := NewLogger("error")
cache := newInMemoryCache()
oidc := &TraefikOidc{
logger: logger,
tokenExchanger: failing,
refreshCoordinator: NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), logger),
refreshResultCache: cache,
}
defer oidc.refreshCoordinator.Shutdown()
if _, err := oidc.coordinatedTokenRefresh(nil, "doomed_refresh_token"); err == nil {
t.Fatal("expected an error from the failing exchanger")
}
if cache.Size() != 0 {
t.Fatalf("error result must not be cached, size=%d", cache.Size())
}
}
+68
View File
@@ -0,0 +1,68 @@
package traefikoidc
import (
"testing"
"time"
"github.com/gorilla/sessions"
)
// sessionWithIssuedAt builds the smallest SessionData that GetRefreshTokenIssuedAt
// reads from. We can't reuse sessionPool.Get() here because that requires a
// fully initialized SessionManager - overkill for this unit-level check.
func sessionWithIssuedAt(t *testing.T, issuedAt time.Time) *SessionData {
t.Helper()
rs := sessions.NewSession(nil, "refresh")
if !issuedAt.IsZero() {
rs.Values["issued_at"] = issuedAt.Unix()
}
return &SessionData{
refreshSession: rs,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
idTokenChunks: make(map[int]*sessions.Session),
}
}
func TestIsRefreshTokenExpired_DisabledWhenAgeZero(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 0}
sd := sessionWithIssuedAt(t, time.Now().Add(-30*24*time.Hour))
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false when maxRefreshTokenAge is 0")
}
}
func TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: time.Hour}
sd := sessionWithIssuedAt(t, time.Time{}) // no issued_at value
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false when issued_at missing (legacy session)")
}
}
func TestIsRefreshTokenExpired_WithinWindow(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
sd := sessionWithIssuedAt(t, time.Now().Add(-1*time.Hour))
if tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=false within max age")
}
}
func TestIsRefreshTokenExpired_BeyondWindow(t *testing.T) {
tr := &TraefikOidc{maxRefreshTokenAge: 6 * time.Hour}
sd := sessionWithIssuedAt(t, time.Now().Add(-7*time.Hour))
if !tr.isRefreshTokenExpired(sd) {
t.Fatal("expected isRefreshTokenExpired=true beyond max age")
}
}
func TestIsRefreshTokenExpired_NilGuards(t *testing.T) {
var tr *TraefikOidc
if tr.isRefreshTokenExpired(nil) {
t.Fatal("nil receiver must not panic and must return false")
}
tr = &TraefikOidc{maxRefreshTokenAge: time.Hour}
if tr.isRefreshTokenExpired(nil) {
t.Fatal("nil session must return false")
}
}
+2 -2
View File
@@ -129,7 +129,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
// Simulate successful Azure authentication
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Azure may use opaque access tokens
session.SetAccessToken("opaque-azure-access-token")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ") // trufflehog:ignore
@@ -152,7 +152,7 @@ func testIssue53ReverseProxyHTTPS(t *testing.T) {
require.NoError(t, err)
assert.True(t, session2.GetAuthenticated(), "User should remain authenticated")
assert.Equal(t, "user@example.com", session2.GetEmail())
assert.Equal(t, "user@example.com", session2.GetUserIdentifier())
assert.NotEmpty(t, session2.GetAccessToken(), "Access token should persist")
assert.NotEmpty(t, session2.GetIDToken(), "ID token should persist")
assert.NotEmpty(t, session2.GetRefreshToken(), "Refresh token should persist")
+2 -2
View File
@@ -485,7 +485,7 @@ func TestSessionFixationAttack(t *testing.T) {
// Set up the attacker's session with malicious data
attackerSession.SetAuthenticated(true)
attackerSession.SetEmail("attacker@evil.com")
attackerSession.SetUserIdentifier("attacker@evil.com")
attackerSession.SetIDToken(ValidIDToken)
attackerSession.SetAccessToken(ValidAccessToken)
@@ -512,7 +512,7 @@ func TestSessionFixationAttack(t *testing.T) {
}
// Get the email from the session
email := session.GetEmail()
email := session.GetUserIdentifier()
w.Header().Set("X-User-Email", email)
w.WriteHeader(http.StatusOK)
})
+26 -26
View File
@@ -100,7 +100,7 @@ type combinedSessionPayload struct {
A string `json:"a,omitempty"`
R string `json:"r,omitempty"`
I string `json:"i,omitempty"`
E string `json:"e,omitempty"`
Ui string `json:"ui,omitempty"`
Cs string `json:"cs,omitempty"`
N string `json:"n,omitempty"`
Cv string `json:"cv,omitempty"`
@@ -113,11 +113,11 @@ type combinedSessionPayload struct {
// knownSessionKeys are the standard keys that are handled explicitly in the combined payload.
// All other mainSession.Values keys are stored in the X (extra) field.
var knownSessionKeys = map[string]bool{
"access_token": true,
"refresh_token": true,
"id_token": true,
"email": true,
"authenticated": true,
"access_token": true,
"refresh_token": true,
"id_token": true,
"user_identifier": true,
"authenticated": true,
"csrf": true,
"nonce": true,
"code_verifier": true,
@@ -1134,7 +1134,7 @@ func (sm *SessionManager) loadFromCombinedCookies(r *http.Request, sessionData *
sessionData.idTokenSession, _ = sm.store.Get(r, sm.idTokenCookieName())
// Populate legacy session values from combined payload
sessionData.mainSession.Values["email"] = payload.E
sessionData.mainSession.Values["user_identifier"] = payload.Ui
sessionData.mainSession.Values["authenticated"] = payload.Au
sessionData.mainSession.Values["csrf"] = payload.Cs
sessionData.mainSession.Values["nonce"] = payload.N
@@ -1278,7 +1278,7 @@ func (sd *SessionData) saveCombined(r *http.Request, w http.ResponseWriter, opti
A: sd.getAccessTokenUnsafe(),
R: sd.getRefreshTokenUnsafe(),
I: sd.getIDTokenUnsafe(),
E: sd.getEmailUnsafe(),
Ui: sd.getUserIdentifierUnsafe(),
Au: sd.getAuthenticatedUnsafe(),
Cs: sd.getCSRFUnsafe(),
N: sd.getNonceUnsafe(),
@@ -2469,30 +2469,30 @@ func (sd *SessionData) SetCodeVerifier(codeVerifier string) {
}
}
// GetEmail retrieves the authenticated user's email address.
// The email is extracted from ID token claims and used for
// authorization decisions and header injection.
// GetUserIdentifier retrieves the authenticated user's identifier as extracted
// from the configured userIdentifierClaim of the ID token (email, sub, oid,
// upn, preferred_username, etc.). The value is used for authorization
// decisions and header injection.
// Returns:
// - The user's email address string, or an empty string if not set.
func (sd *SessionData) GetEmail() string {
// - The user identifier string, or an empty string if not set.
func (sd *SessionData) GetUserIdentifier() string {
sd.sessionMutex.RLock()
defer sd.sessionMutex.RUnlock()
email, _ := sd.mainSession.Values["email"].(string)
return email
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// SetEmail stores the authenticated user's email address.
// The email is typically extracted from the 'email' claim in the ID token.
// SetUserIdentifier stores the authenticated user's identifier value.
// Parameters:
// - email: The user's email address to store.
func (sd *SessionData) SetEmail(email string) {
// - userIdentifier: The user identifier to store (email, sub, or other claim value).
func (sd *SessionData) SetUserIdentifier(userIdentifier string) {
sd.sessionMutex.Lock()
defer sd.sessionMutex.Unlock()
currentVal, _ := sd.mainSession.Values["email"].(string)
if currentVal != email {
sd.mainSession.Values["email"] = email
currentVal, _ := sd.mainSession.Values["user_identifier"].(string)
if currentVal != userIdentifier {
sd.mainSession.Values["user_identifier"] = userIdentifier
sd.dirty = true
}
}
@@ -2626,10 +2626,10 @@ func (sd *SessionData) getRefreshTokenUnsafe() string {
return result.Token
}
// getEmailUnsafe retrieves the email without acquiring locks.
func (sd *SessionData) getEmailUnsafe() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
// getUserIdentifierUnsafe retrieves the user identifier without acquiring locks.
func (sd *SessionData) getUserIdentifierUnsafe() string {
userIdentifier, _ := sd.mainSession.Values["user_identifier"].(string)
return userIdentifier
}
// getCSRFUnsafe retrieves the CSRF token without acquiring locks.
+6 -7
View File
@@ -320,17 +320,16 @@ func (s *SessionBehaviourSuite) TestSessionData_DirtyTracking() {
s.False(session.IsDirty())
}
// TestSessionData_SetEmail tests email setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetEmail() {
// TestSessionData_SetUserIdentifier tests user identifier setter with dirty tracking
func (s *SessionBehaviourSuite) TestSessionData_SetUserIdentifier() {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
session, err := s.sessionManager.GetSession(req)
s.Require().NoError(err)
defer session.returnToPoolSafely()
// Set email
session.SetEmail("test@example.com")
s.Equal("test@example.com", session.GetEmail())
session.SetUserIdentifier("test@example.com")
s.Equal("test@example.com", session.GetUserIdentifier())
s.True(session.IsDirty())
}
@@ -568,7 +567,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Clear() {
// Set some data
err = session.SetAuthenticated(true)
s.Require().NoError(err)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetCSRF("csrf-token")
// Clear session
@@ -588,7 +587,7 @@ func (s *SessionBehaviourSuite) TestSessionData_Save() {
defer session.returnToPoolSafely()
// Modify session
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
s.True(session.IsDirty())
// Save session
+6 -6
View File
@@ -2688,7 +2688,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Set up initial session state (what user has when first logging in)
session1.SetAuthenticated(true)
session1.SetEmail(originalUserData["email"].(string))
session1.SetUserIdentifier(originalUserData["email"].(string))
session1.SetAccessToken("initial-valid-access-token-longer-than-20-chars")
session1.SetIDToken("initial-valid-id-token-longer-than-20-chars")
session1.SetRefreshToken("valid-refresh-token-should-last-30-days")
@@ -2732,7 +2732,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Simulate what happens when middleware detects expired tokens
// It should preserve session state while attempting token refresh
originalAuth := session2.GetAuthenticated()
originalEmail := session2.GetEmail()
originalEmail := session2.GetUserIdentifier()
// Reconstruct user data from individual stored keys
originalUserDataStored := make(map[string]interface{})
@@ -2813,7 +2813,7 @@ func TestSessionStatePreservationWithExpiredTokens(t *testing.T) {
// Verify all session data is still intact after token refresh
postRefreshAuth := session2.GetAuthenticated()
postRefreshEmail := session2.GetEmail()
postRefreshEmail := session2.GetUserIdentifier()
userDataPresent := true
for k := range originalUserData {
if session2.mainSession.Values["user_data_"+k] == nil {
@@ -2907,7 +2907,7 @@ func TestSessionExpiryVsTokenExpiry(t *testing.T) {
// Set up session with specific creation time
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.mainSession.Values["created_at"] = sessionCreatedAt.Unix()
// Create tokens with specific expiry
@@ -3018,7 +3018,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
// Set up session with data that should be preserved or removed
session.SetAuthenticated(true)
session.SetEmail("cleanup@example.com")
session.SetUserIdentifier("cleanup@example.com")
session.mainSession.Values["user_data"] = "Test User|user-123"
session.mainSession.Values["preferences"] = "theme:dark,lang:en"
@@ -3049,7 +3049,7 @@ func TestSessionCleanupOnTokenExpiry(t *testing.T) {
if scenario.shouldCleanup {
if sessionTooOld {
session.SetAuthenticated(false)
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
for key := range session.mainSession.Values {
+71 -2
View File
@@ -55,6 +55,15 @@ type Config struct {
AllowedUsers []string `json:"allowedUsers"`
Headers []TemplatedHeader `json:"headers"`
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
// a stored refresh token. Once the token has been in the session longer
// than this, requests treat it as expired up-front - returning 401 to
// AJAX callers and triggering full re-auth on navigations - instead of
// hammering the IdP with grants that will only fail with invalid_grant.
// IdPs do not expose RT TTL on the wire, so this is intentionally a
// conservative heuristic; tune to match your provider configuration.
// Default 21600 (6h). Set to 0 to disable the check.
MaxRefreshTokenAgeSeconds int `json:"maxRefreshTokenAgeSeconds"`
SessionMaxAge int `json:"sessionMaxAge"`
RateLimit int `json:"rateLimit"`
OverrideScopes bool `json:"overrideScopes"`
@@ -84,6 +93,38 @@ type Config struct {
// providers. Enabling this in production is a security hole — prefer
// CACertPath/CACertPEM. Emits a loud warning at startup.
InsecureSkipVerify bool `json:"insecureSkipVerify,omitempty"`
// ClientAuthMethod selects the OAuth 2.0 client authentication method used
// at the token / revocation / introspection endpoints. Supported values:
//
// - "client_secret_post" (default, current behavior): clientSecret is
// sent in the request body alongside client_id.
// - "private_key_jwt" (RFC 7523 §2.2): the plugin signs a short-lived JWT
// assertion with a configured private key and sends it as
// client_assertion. Use this when your IdP enforces short-lived secrets
// or mandates secretless client auth (Entra ID, Okta, Auth0, Keycloak).
//
// When set to "private_key_jwt", clientSecret may be left empty and one of
// clientAssertionPrivateKey / clientAssertionKeyPath must be configured.
ClientAuthMethod string `json:"clientAuthMethod,omitempty"`
// ClientAssertionPrivateKey is an inline PEM-encoded private key used to
// sign client_assertion JWTs. Mutually exclusive with
// ClientAssertionKeyPath. Supports PKCS#8, PKCS#1 (RSA), and SEC1 (EC).
ClientAssertionPrivateKey string `json:"clientAssertionPrivateKey,omitempty"`
// ClientAssertionKeyPath is a filesystem path to a PEM-encoded private key,
// equivalent to ClientAssertionPrivateKey but loaded from disk.
ClientAssertionKeyPath string `json:"clientAssertionKeyPath,omitempty"`
// ClientAssertionKeyID is the JWK key id (kid) advertised in the JWS
// header. Required when using private_key_jwt so the IdP can locate the
// matching public key registered for the client.
ClientAssertionKeyID string `json:"clientAssertionKeyID,omitempty"`
// ClientAssertionAlg is the JWS signing algorithm. Defaults to RS256.
// Supported: RS256/384/512, PS256/384/512, ES256/384/512.
ClientAssertionAlg string `json:"clientAssertionAlg,omitempty"`
}
// loadCACertPool assembles an x509.CertPool from CACertPath and CACertPEM.
@@ -247,6 +288,7 @@ func CreateConfig() *Config {
EnablePKCE: false, // PKCE is opt-in
OverrideScopes: false, // Default to appending scopes, not overriding
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
MaxRefreshTokenAgeSeconds: 21600, // 6h - conservative heuristic, see field doc
SecurityHeaders: createDefaultSecurityConfig(),
Redis: nil, // Redis is disabled by default, configure via Traefik or env vars
}
@@ -313,8 +355,30 @@ func (c *Config) Validate() error {
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
authMethod := c.ClientAuthMethod
if authMethod == "" {
authMethod = "client_secret_post"
}
switch authMethod {
case "client_secret_post", "client_secret_basic":
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required when clientAuthMethod is %q", authMethod)
}
case "private_key_jwt":
if c.ClientAssertionPrivateKey == "" && c.ClientAssertionKeyPath == "" {
return fmt.Errorf("clientAssertionPrivateKey or clientAssertionKeyPath is required when clientAuthMethod is private_key_jwt")
}
if c.ClientAssertionPrivateKey != "" && c.ClientAssertionKeyPath != "" {
return fmt.Errorf("only one of clientAssertionPrivateKey or clientAssertionKeyPath may be set")
}
if c.ClientAssertionKeyID == "" {
return fmt.Errorf("clientAssertionKeyID is required when clientAuthMethod is private_key_jwt")
}
if c.ClientAssertionAlg != "" && !isSupportedClientAssertionAlg(c.ClientAssertionAlg) {
return fmt.Errorf("clientAssertionAlg %q is not supported (use RS256/384/512, PS256/384/512, or ES256/384/512)", c.ClientAssertionAlg)
}
default:
return fmt.Errorf("clientAuthMethod %q is not supported", authMethod)
}
// Validate session encryption key
@@ -370,6 +434,11 @@ func (c *Config) Validate() error {
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
}
// Validate refresh-token max-age heuristic
if c.MaxRefreshTokenAgeSeconds < 0 {
return fmt.Errorf("maxRefreshTokenAgeSeconds cannot be negative")
}
// Validate audience if specified
if c.Audience != "" {
// Validate audience format - should be a valid identifier or URL
+1 -1
View File
@@ -293,7 +293,7 @@ func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.
}
session.SetAuthenticated(true)
session.SetEmail(tf.fixtures.UserEmail)
session.SetUserIdentifier(tf.fixtures.UserEmail)
session.SetAccessToken(tf.fixtures.AccessToken)
session.SetRefreshToken(tf.fixtures.RefreshToken)
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
+213 -29
View File
@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"
)
@@ -46,6 +47,17 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
}
// Hot-path fast-return: a previously-verified token has already passed
// signature, claims, and replay checks. Skipping the parseJWT cost here
// matters under bursty traffic (e.g. 10+ concurrent panel requests on
// every Grafana dashboard refresh) where the same token is validated
// dozens of times per second by validateStandardTokens.
if t.tokenCache != nil {
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
return nil
}
}
parsedJWT, parseErr := parseJWT(token)
if parseErr != nil {
return fmt.Errorf("failed to parse JWT for blacklist check: %w", parseErr)
@@ -63,12 +75,6 @@ func (t *TraefikOidc) VerifyToken(token string) error {
}
}
// Check token cache FIRST - if token is already verified and cached, return immediately
// This prevents false positives when multiple goroutines validate the same token concurrently
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
return nil
}
// Only check JTI blacklist for tokens that aren't already in the cache
// This is for FIRST-TIME validation to detect replay attacks
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
@@ -335,7 +341,17 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
if err := verifySignatureWithKey(token, pubKey, alg); err != nil {
if !t.suppressDiagnosticLogs {
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
// Microsoft Graph access tokens carry a `nonce` JWT header and are
// signed in a proprietary form Microsoft documents as unverifiable
// by client applications. They reach this path only when the
// per-provider classifier (validateAzureTokens) didn't catch them,
// so log at debug to keep the error stream actionable while still
// surfacing the cause for diagnostics.
if _, isMSProprietary := jwt.Header["nonce"]; isMSProprietary {
t.safeLogDebugf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s (Microsoft proprietary nonce header — token is opaque to clients): %v", kid, alg, err)
} else {
t.safeLogErrorf("DIAGNOSTIC: Signature verification failed for kid=%s, alg=%s: %v", kid, alg, err)
}
}
return fmt.Errorf("signature verification failed: %w", err)
}
@@ -416,7 +432,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
}
t.logger.Debugf("Attempting refresh with token starting with %s...", tokenPrefix)
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
newToken, err := t.coordinatedTokenRefresh(req, initialRefreshToken)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
@@ -428,7 +444,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetRefreshToken("")
session.SetAccessToken("")
session.SetIDToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens as well to prevent any replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -470,12 +486,18 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
t.logger.Errorf("refreshToken failed: Failed to extract claims from refreshed token: %v", err)
return false
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Errorf("refreshToken failed: Email claim missing or empty in refreshed token")
return false
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
if userIdentifier == "" {
if t.userIdentifierClaim != "sub" {
userIdentifier, _ = claims["sub"].(string)
}
if userIdentifier == "" {
t.logger.Errorf("refreshToken failed: User identifier claim '%s' missing or empty in refreshed token", t.userIdentifierClaim)
return false
}
t.logger.Debugf("Configured claim '%s' not found in refreshed token, using 'sub' claim as fallback", t.userIdentifierClaim)
}
session.SetEmail(email)
session.SetUserIdentifier(userIdentifier)
// Get token expiry information for logging
var expiryTime time.Time
@@ -501,7 +523,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
session.SetAccessToken("")
session.SetIDToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
return false
}
@@ -518,6 +540,91 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// coordinatedTokenRefresh routes a refresh-token grant through the
// RefreshCoordinator so that concurrent requests sharing the same refresh
// token coalesce into a single upstream call. This prevents the thundering
// herd that yields invalid_grant when the IdP rotates refresh tokens.
//
// Falls back to a direct call when the coordinator is nil, which only
// happens in tests that build TraefikOidc literals without going through
// NewWithContext.
func (t *TraefikOidc) coordinatedTokenRefresh(req *http.Request, refreshToken string) (*TokenResponse, error) {
if t.refreshCoordinator == nil {
return t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
}
parentCtx := context.Background()
if req != nil {
parentCtx = req.Context()
}
ctx, cancel := context.WithTimeout(parentCtx, refreshCoordinatorWaitTimeout)
defer cancel()
sessionID := refreshCoordinatorSessionID(refreshToken)
return t.refreshCoordinator.CoordinateRefresh(
ctx,
sessionID,
refreshToken,
func() (*TokenResponse, error) {
// Cross-replica dedup. The in-process coordinator already
// collapses concurrent grants on this pod; this Redis-backed
// short-TTL cache covers the (rare) case of a failover or
// load-balancer reroute mid-refresh, where two pods would
// otherwise both POST the same refresh_token to the IdP.
if cached, ok := t.lookupCachedRefreshResult(sessionID); ok {
return cached, nil
}
resp, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
if err == nil && resp != nil {
t.cacheRefreshResult(sessionID, resp)
}
return resp, err
},
)
}
// lookupCachedRefreshResult returns a previously-stored TokenResponse for the
// given refresh-token hash, if one exists and is still within its short TTL.
// The cache wraps the universal cache, which is Redis-backed in production -
// so a "hit" here means another Traefik replica refreshed this same token
// within the last few seconds.
func (t *TraefikOidc) lookupCachedRefreshResult(sessionID string) (*TokenResponse, bool) {
if t.refreshResultCache == nil {
return nil, false
}
v, ok := t.refreshResultCache.Get(refreshResultCacheKey(sessionID))
if !ok || v == nil {
return nil, false
}
if tr, ok := v.(*TokenResponse); ok && tr != nil {
return tr, true
}
return nil, false
}
// cacheRefreshResult stores the new TokenResponse under the refresh-token
// hash for a short window. TTL is intentionally tight: the rotated refresh
// token cannot be re-presented to the IdP, and any peer waiting longer than
// this window has almost certainly given up via its own coordinator timeout.
func (t *TraefikOidc) cacheRefreshResult(sessionID string, resp *TokenResponse) {
if t.refreshResultCache == nil || resp == nil {
return
}
t.refreshResultCache.Set(refreshResultCacheKey(sessionID), resp, refreshResultCacheTTL)
}
// refreshResultCacheKey namespaces refresh-result entries inside the shared
// cache namespace.
func refreshResultCacheKey(sessionID string) string {
return "rt-result:" + sessionID
}
// refreshResultCacheTTL bounds how long a peer can lean on the dedup cache.
// Long enough for a sibling replica to observe the result, short enough that
// a stale entry never re-supplies a token after the IdP has already moved on.
const refreshResultCacheTTL = 5 * time.Second
// RevokeToken revokes a token locally by adding it to the blacklist cache.
// It removes the token from the verification cache and adds both the token
// and its JTI (if present) to the blacklist to prevent future use.
@@ -563,11 +670,33 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, revocationURL)
// Read tokenURL with RLock — used as audience for private_key_jwt (RFC 7523 §3).
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
data := url.Values{
"token": {token},
"token_type_hint": {tokenType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
// client_id is sent in the body for every method except client_secret_basic,
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
data.Set("client_id", t.clientID)
}
useBasicAuth := false
if t.clientAssertion != nil {
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
if err != nil {
return fmt.Errorf("failed to sign client assertion: %w", err)
}
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
data.Set("client_assertion", assertion)
} else if t.clientAuthMethod == "client_secret_basic" {
useBasicAuth = true
} else {
data.Set("client_secret", t.clientSecret)
}
req, err := http.NewRequestWithContext(context.Background(), "POST", revocationURL, strings.NewReader(data.Encode()))
@@ -577,6 +706,9 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
if useBasicAuth {
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
}
// Send the request with circuit breaker protection if available
var resp *http.Response
@@ -663,6 +795,27 @@ func (t *TraefikOidc) isGoogleProvider() bool {
return strings.Contains(issuerURL, "google") || strings.Contains(issuerURL, "accounts.google.com")
}
// isUnverifiableAzureAccessToken reports whether a JWT-shaped access token
// matches the Microsoft proprietary format that client applications must not
// validate. Microsoft injects a `nonce` value into the JWT header, signs over
// the SHA256 hash of that nonce, and ships the original nonce on the wire,
// guaranteeing that any standard JWS verifier rejects the signature. This is
// the documented mechanism that keeps access tokens opaque to non-resource
// holders (Microsoft Graph, Azure Management API).
//
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
//
// Returns true on parse failure as well — a token we cannot parse should not
// be passed through the verification path that emits ERROR logs.
func (t *TraefikOidc) isUnverifiableAzureAccessToken(token string) bool {
parsed, err := parseJWT(token)
if err != nil {
return true
}
_, hasProprietaryNonce := parsed.Header["nonce"]
return hasProprietaryNonce
}
// isAzureProvider detects if the configured OIDC provider is Azure AD.
// It checks the issuer URL for Microsoft Azure AD domains.
// Returns:
@@ -705,6 +858,31 @@ func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, boo
if accessToken != "" {
if strings.Count(accessToken, ".") == 2 {
// Microsoft documents that client apps cannot validate access
// tokens issued for Microsoft-owned APIs (Graph, Azure Mgmt) due
// to their proprietary signing format (nonce in JWT header is
// the marker — signed bytes hash the nonce, wire bytes ship the
// raw value, so rsa verification always fails). Treat such
// tokens as opaque, matching Microsoft's guidance and avoiding
// per-request signature-error log spam (issue #134 followup).
//
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
// "you can't validate tokens for Microsoft Graph according to
// these rules due to their proprietary format"
if t.isUnverifiableAzureAccessToken(accessToken) {
t.logger.Debug("Azure access token is Microsoft-proprietary (Graph/Mgmt) — treating as opaque per Microsoft guidance")
if idToken != "" {
if err := t.verifyToken(idToken); err != nil {
t.logger.Debugf("Azure: ID token validation failed while access token was opaque: %v", err)
if session.GetRefreshToken() != "" {
return false, true, false
}
return false, false, true
}
return t.validateTokenExpiry(session, idToken)
}
return true, false, false
}
if err := t.verifyToken(accessToken); err != nil {
if idToken != "" {
if err := t.verifyToken(idToken); err != nil {
@@ -1103,9 +1281,14 @@ func (t *TraefikOidc) startTokenCleanup() {
sessionManager := t.sessionManager
logger := t.logger
// Only use the fast cleanup interval when actually running under `go test`.
// runtime.Compiler == "yaegi" makes isTestMode() return true in production
// (Traefik interprets the plugin via yaegi), which would otherwise pin this
// ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and
// jwkCache.Cleanup both being no-ops there.
cleanupInterval := 1 * time.Minute
if isTestMode() {
cleanupInterval = 50 * time.Millisecond // Fast interval for tests
if isTestMode() && runtime.Compiler != "yaegi" {
cleanupInterval = 50 * time.Millisecond
}
// Create cleanup function
@@ -1147,25 +1330,27 @@ func (t *TraefikOidc) startTokenCleanup() {
}
// extractGroupsAndRoles extracts group and role information from token claims.
// It parses the 'groups' and 'roles' claims from the ID token and validates their format.
// Parameters:
// - idToken: The ID token containing claims to extract.
// It parses the configured group/role claims from the supplied ID token.
//
// Returns:
// - groups: Array of group names from the 'groups' claim.
// - roles: Array of role names from the 'roles' claim.
// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present
// but not arrays of strings.
// Most callers should prefer extractGroupsAndRolesFromClaims when claims have
// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims),
// to avoid re-parsing the JWT.
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
return t.extractGroupsAndRolesFromClaims(claims)
}
// extractGroupsAndRolesFromClaims extracts group and role information from
// already-parsed claims. Hot path: callers that have a cached claims map (such
// as SessionData.GetIDTokenClaims) should use this to skip a redundant
// base64+JSON decode of the JWT on every authenticated request.
func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) {
var groups []string
var roles []string
// Extract groups using configurable claim name (defaults to "groups")
if groupsClaim, exists := claims[t.groupClaimName]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
@@ -1181,7 +1366,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
}
}
// Extract roles using configurable claim name (defaults to "roles")
if rolesClaim, exists := claims[t.roleClaimName]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
+5
View File
@@ -95,6 +95,7 @@ type TraefikOidc struct {
cancelFunc context.CancelFunc
errorRecoveryManager *ErrorRecoveryManager
tokenResilienceManager *TokenResilienceManager
refreshCoordinator *RefreshCoordinator
goroutineWG *sync.WaitGroup
dcrConfig *DynamicClientRegistrationConfig
dynamicClientRegistrar *DynamicClientRegistrar
@@ -118,17 +119,21 @@ type TraefikOidc struct {
audience string
clientID string
clientSecret string
clientAuthMethod string
clientAssertion *ClientAssertionSigner
registrationURL string
backchannelLogoutPath string
frontchannelLogoutPath string
scopesSupported []string
scopes []string
refreshGracePeriod time.Duration
maxRefreshTokenAge time.Duration
metadataMu sync.RWMutex
shutdownOnce sync.Once
metadataRetryMutex sync.Mutex
firstRequestMutex sync.Mutex
sessionInvalidationCache CacheInterface
refreshResultCache CacheInterface
minimalHeaders bool
stripAuthCookies bool
enableBackchannelLogout bool
+32
View File
@@ -252,6 +252,25 @@ func (c *UniversalCache) Set(key string, value interface{}, ttl time.Duration) e
}
}
return c.setLocal(key, value, ttl)
}
// SetLocal stores a value only in the in-memory LRU, bypassing any
// distributed backend. Use for values that don't survive JSON round-tripping
// — interfaces holding concrete crypto keys, *big.Int, or types whose
// unexported fields yaegi exposes under an X prefix on Marshal. Each replica
// caches independently; correctness must not depend on cross-replica
// coherence for these keys.
func (c *UniversalCache) SetLocal(key string, value interface{}, ttl time.Duration) error {
if ttl == 0 {
ttl = c.config.DefaultTTL
}
return c.setLocal(key, value, ttl)
}
// setLocal performs the in-memory portion of a write. ttl must already be
// resolved against DefaultTTL by the caller.
func (c *UniversalCache) setLocal(key string, value interface{}, ttl time.Duration) error {
size := c.estimateSize(value)
c.mu.Lock()
@@ -343,6 +362,19 @@ func (c *UniversalCache) Get(key string) (interface{}, bool) {
}
}
return c.getLocal(key)
}
// GetLocal retrieves a value only from the in-memory LRU, never querying the
// distributed backend. Pair with SetLocal for values that aren't safe to
// serialize (see SetLocal docstring).
func (c *UniversalCache) GetLocal(key string) (interface{}, bool) {
return c.getLocal(key)
}
// getLocal returns the in-memory entry for key honoring expiry, grace
// periods, and the RLock fast path used by token/JWK/session caches.
func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
// Fast read path for caches whose eviction is dominated by TTL rather than
// access-recency (token, JWK, session). Holding only an RLock here lets all
// concurrent readers verify cached tokens in parallel — under yaegi the
+40 -1
View File
@@ -23,6 +23,7 @@ type UniversalCacheManager struct {
metadataCache *UniversalCache
dcrCredentialsCache *UniversalCache // DCR credentials storage for distributed environments
sessionInvalidationCache *UniversalCache // Session invalidation cache for backchannel/front-channel logout
refreshResultCache *UniversalCache // Short-lived cross-replica refresh-result dedup (paired with RefreshCoordinator)
logger *Logger
blacklistCache *UniversalCache
cancel context.CancelFunc
@@ -181,6 +182,18 @@ func initializeDefaultCaches(manager *UniversalCacheManager, logger *Logger) {
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
// Refresh-result cache: short-lived store keyed by sha256(refreshToken).
// In Redis-backed mode this gives cross-replica dedup of refresh grants;
// in memory-only mode it's effectively redundant with RefreshCoordinator
// but safe and cheap to keep.
manager.refreshResultCache = NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 5 * time.Second,
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
})
}
// initializeCachesWithRedis initializes caches with Redis/Hybrid backends based on configuration
@@ -197,6 +210,8 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
RedisPrefix: redisConfig.KeyPrefix,
PoolSize: redisConfig.PoolSize,
EnableMetrics: true,
EnableTLS: redisConfig.EnableTLS,
TLSSkipVerify: redisConfig.TLSSkipVerify,
}
// Use concrete type to avoid Yaegi reflection issues with interface assignment
@@ -387,6 +402,21 @@ func initializeCachesWithRedis(manager *UniversalCacheManager, logger *Logger, r
createBackend("session_invalidation"),
)
// Refresh-result cache - shared via Redis so concurrent refreshes across
// Traefik replicas can dedup their grants. The 5s TTL is long enough for
// peers to observe a recent refresh and short enough that a stale entry
// can't be replayed against a now-rotated refresh token.
manager.refreshResultCache = NewUniversalCacheWithBackend(
UniversalCacheConfig{
Type: CacheTypeToken,
MaxSize: 1000,
DefaultTTL: 5 * time.Second,
Logger: logger,
SkipAutoCleanup: true, // Managed cleanup
},
createBackend("refresh_result"),
)
logger.Infof("Cache manager initialized with %s backend configuration", redisConfig.CacheMode)
}
@@ -436,6 +466,7 @@ func (m *UniversalCacheManager) performConsolidatedCleanup() {
m.tokenTypeCache,
m.dcrCredentialsCache,
m.sessionInvalidationCache,
m.refreshResultCache,
}
m.mu.RUnlock()
@@ -498,6 +529,14 @@ func (m *UniversalCacheManager) GetSessionInvalidationCache() *UniversalCache {
return m.sessionInvalidationCache
}
// GetRefreshResultCache returns the short-lived refresh-result cache used to
// coalesce refresh-token grants across Traefik replicas.
func (m *UniversalCacheManager) GetRefreshResultCache() *UniversalCache {
m.mu.RLock()
defer m.mu.RUnlock()
return m.refreshResultCache
}
// GetDCRCredentialsCache returns the DCR credentials cache for distributed storage
func (m *UniversalCacheManager) GetDCRCredentialsCache() *UniversalCache {
m.mu.RLock()
@@ -520,7 +559,7 @@ func (m *UniversalCacheManager) Close() error {
// Close all caches first (they won't close the shared backend)
for _, cache := range []*UniversalCache{
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache,
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache, m.dcrCredentialsCache, m.sessionInvalidationCache, m.refreshResultCache,
} {
if cache != nil {
_ = cache.Close() // Safe to ignore: best effort cache cleanup
+5
View File
@@ -250,6 +250,11 @@ func (t *TraefikOidc) Close() error {
t.safeLogDebug("metadataRefreshStopChan closed")
}
if t.refreshCoordinator != nil {
t.refreshCoordinator.Shutdown()
t.safeLogDebug("refreshCoordinator shut down")
}
if t.goroutineWG != nil {
done := make(chan struct{})
go func() {
+62
View File
@@ -0,0 +1,62 @@
package traefikoidc
import (
"testing"
"time"
"golang.org/x/time/rate"
)
// TestVerifyToken_CacheHitSkipsParse proves the hot-path optimization: when a
// token is in the cache, VerifyToken returns nil without calling parseJWT.
// We construct a token that PASSES the cheap format checks (3 segments, len
// >= 10) but whose body is unparseable JSON. With the cache hit hoisted ahead
// of parseJWT, the function returns nil. Without the hoist, parseJWT would
// fail with "failed to parse JWT for blacklist check".
func TestVerifyToken_CacheHitSkipsParse(t *testing.T) {
tr := &TraefikOidc{
logger: NewLogger("error"),
tokenCache: NewTokenCache(),
// limiter intentionally absent; if we reached the rate-limit check
// the test would NPE - this is a stronger assertion that we exit
// before that point.
limiter: rate.NewLimiter(rate.Inf, 1),
}
tr.tokenVerifier = tr
// Three segments separated by '.', body is junk after base64-decode + JSON.
// Pre-fix this fails parseJWT; post-fix it returns nil because the cache
// short-circuits.
junkToken := "header.bm90LWpzb24.signature" // base64(not-json) in the middle
tr.tokenCache.Set(junkToken, map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
"sub": "test",
}, time.Hour)
if err := tr.VerifyToken(junkToken); err != nil {
t.Fatalf("expected cache-hit fast path to return nil, got: %v", err)
}
}
// TestVerifyToken_CacheMissStillParses ensures we did not skip too aggressively
// - on a cache miss, the function must still parse and reach the rate-limit
// check. We assert by passing a syntactically valid token whose signature
// won't verify, expecting an error from later in the pipeline.
func TestVerifyToken_CacheMissStillParses(t *testing.T) {
tr := &TraefikOidc{
logger: NewLogger("error"),
tokenCache: NewTokenCache(),
limiter: rate.NewLimiter(rate.Inf, 1),
// no tokenBlacklist, no jwkCache - the function will fail somewhere
// after parseJWT. We just need a non-nil error to confirm we did
// progress past the cache check.
}
tr.tokenVerifier = tr
// Real JWT structure but unsigned/unverifiable.
rawToken := "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"
if err := tr.VerifyToken(rawToken); err == nil {
t.Fatal("expected an error past parseJWT for an unsigned token, got nil")
}
}