Compare commits

..

36 Commits

Author SHA1 Message Date
lukaszraczylo 546ceb949c security: remediate audit findings (ranks 1–16 + 22 Lows) + yaegi load validation (#144)
* fix(security): encrypt session cookies + fail closed on invalid config

Batch 1 of security audit remediation (ranks 1, 2, 6).

- session.go: derive independent HMAC + AES-256 keys via stdlib HKDF-SHA256
  and build the gorilla cookie store with both, so session cookies are now
  encrypted, not merely signed. The single-key store previously left OIDC
  access/refresh/ID tokens recoverable from raw cookie bytes. Cookie format
  changes, so existing sessions are invalidated on deploy (one-time re-login).
- main.go: call config.Validate() at construction and error out on failure,
  instead of silently substituting a public hardcoded encryption key for
  empty/short keys (which allowed session forgery). The yaegi analyzer
  passes via .traefik.yml testData.
- settings.go: isValidSecureURL permits plaintext HTTP for loopback hosts
  only (RFC 8252); remote providers must still use HTTPS.
- tests: complete configs that did not satisfy Validate(); add regression
  tests in security_audit_fixes_test.go.

Configs below documented minimums (rateLimit < 10, key < 32 chars) are now
rejected at startup (fail closed).

* fix(security): validate discovered OIDC endpoints + pin introspection host

Batch 2 of security audit remediation (ranks 3, 4).

- url_helpers.go: add validateDiscoveredEndpoint, an SSRF screen for endpoints
  taken from the provider discovery document (jwks_uri, token, authorization,
  revocation, end_session, introspection, registration). Blocks link-local
  (cloud metadata 169.254.169.254), multicast, unspecified and private
  addresses (unless allowPrivateIPAddresses); blocks loopback unless the
  configured providerURL is itself loopback (dev/test). Cross-domain JWKS
  hosts (e.g. Google) stay allowed. Add sameHost helper.
- main.go: updateMetadataEndpoints screens every discovered endpoint and
  blanks any that fail (fail closed downstream). The introspection endpoint
  carries the client secret via HTTP Basic, so it is additionally pinned to
  the providerURL host to stop a poisoned discovery document exfiltrating the
  secret to an attacker-controlled host.
- tests: regression tests for the SSRF guard and the host pin.

* fix(security): close open redirects + anchor excluded-URL matching

Batch 3 of security audit remediation (ranks 5, 14, 15).

- auth_flow.go: run the stored incoming path through normalizeLogoutPath
  before using it as the post-login redirect, so //evil.com and /\evil.com
  payloads become host-relative (open-redirect, rank 5).
- url_helpers.go: excluded-URL matching is anchored at a natural boundary
  (exact, sub-path "/", or file extension "."), so excluding "/public" no
  longer also bypasses auth on "/publicsecret"; "/favicon" still matches
  "/favicon.ico" (rank 14).
- internal/utils: X-Forwarded-Host is sanitized (first value only; reject
  CRLF/whitespace/multi-value) before building redirect URLs (rank 15).
- helpers.go: the logout redirect used when there is no provider end-session
  endpoint is host-relative, never an absolute URL derived from the
  client-controllable request host (logout open-redirect, rank 15).
- tests: update two logout cases that asserted the old absolute redirect;
  add regression tests.

* fix(security): reject unverified Azure tokens; fix transport TLS reuse

Batch 4 of security audit remediation (ranks 7, 11).

- token_validation_rs.go: an Azure nonce-bearing access token that cannot be
  cryptographically verified no longer returns "authenticated" when there is
  no ID token to corroborate it; it refreshes (if possible) or forces
  re-authentication instead of failing open (rank 7).
- http_client_pool.go: the at-limit transport-reuse path now takes the write
  lock before mutating refCount (fixes a data race) and only reuses a
  transport whose TLS settings (CA pool + InsecureSkipVerify) match the
  caller's, never one with a different trust store; if none matches it returns
  nil so the caller falls back to a verifying default transport (rank 11).
- tests: add a transport-pool TLS-isolation regression test.

* fix(security): stop logging templated header values (token leak)

Batch 5 of security audit remediation (rank 16).

middleware.go: templated downstream headers commonly carry the access token
(e.g. "Authorization: Bearer {{.AccessToken}}"). The debug log line printed
the full header value, leaking credentials into logs. Log the header name and
byte length instead.

* fix(security): cache-key collision, cache-config divergence, fleet cleanup

Batch 6 of security audit remediation (ranks 9, 10, 12).

- token_manager.go: detectTokenType keys its cache on a SHA-256 hash of the
  full token instead of the first 32 chars (which are only the base64url JWT
  header). Distinct tokens sharing alg+kid no longer collide and get
  mis-classified (rank 10).
- cache_manager.go: the process-global cache manager is initialized once and
  shared across plugin instances; it now logs a loud warning when a later
  instance requests a different explicit Redis backend that is silently
  ignored, surfacing the cross-instance state-isolation hazard (rank 9).
- singleton_resources.go / main.go / utilities.go: track a process-global live
  instance count; the shared singleton-token-cleanup task is stopped only when
  the LAST instance shuts down, so one instance's Close() (e.g. a config reload)
  no longer kills cleanup for surviving instances (rank 12).
- tests: update TestDetectTokenTypeCaching for the new key; add regression tests.

* fix(security): bound introspection cache + cookie lifetime to config

Batch 7 of security audit remediation (ranks 8, 13).

- token_introspection.go: when requireTokenIntrospection is enabled, cap the
  positive introspection-result cache at 30s (instead of 5m) so a token
  revoked at the provider stops passing within ~30s, matching the operator's
  near-real-time revocation expectation (rank 8).
- session.go: bind the cookie store's MaxAge to the configured sessionMaxAge,
  so the cookie codec's cryptographic timestamp validity is no longer fixed at
  gorilla's 30-day default; a stolen cookie is valid only for the configured
  session lifetime (rank 13).
- tests: add a cookie-lifetime regression test.

* fix(security): low-severity hardening (cache, DoS caps, PKCE, throttle)

Batch 8 of security audit remediation — low severity
(ranks 24, 25, 27, 29, 31, 36, 37, 41, 45, 46, 49).

- universal_cache.go: updateLocalCache updates an existing key in place instead
  of orphaning its LRU element and double-counting currentSize/currentMemory
  (rank 36 — the only production-reachable bug in this batch).
- jwk.go / metadata_cache.go / token_introspection.go: bound response bodies
  with io.LimitReader (1 MiB) to prevent memory exhaustion from a hostile or
  buggy provider (ranks 24, 25).
- jwk.go: skip JWKs not usable for signature verification (use != sig, or
  key_ops without "verify") when building the key set (rank 49).
- auth_flow.go: fail closed at the callback when PKCE is enabled but the code
  verifier is missing, instead of silently dropping it (rank 27).
- utilities.go / main.go: match allowedUserDomains case-insensitively (rank 31).
- bearer_auth.go: a single success no longer wipes an active per-IP penalty;
  the counter resets only when no penalty is in effect (rank 29).
- main.go: handle (not discard) the NewSessionManager error (rank 37).
- error_recovery.go: take a write lock in isServiceDegraded (it deletes from a
  map); compare retryable-error substrings case-insensitively (ranks 45, 46).
- singleton_resources.go: bind the generic-cache cleanup goroutine to the
  resource-manager shutdown channel so it cannot outlive its owner (rank 41).
- tests: update the bearer throttle test to the corrected penalty semantics.

* fix(security): header sanitization, issuer pinning, fail-closed paths

Batch 9 of security audit remediation (ranks 18, 19, 20, 21, 22, 30, 33, 34).

- middleware.go / bearer_auth.go: sanitize claim-derived values on the cookie
  auth path before injecting them into downstream headers. Drop group/role and
  identifier values containing control chars, bidi-override runes, or the
  , ; = delimiters (a comma would inject phantom entries into X-User-Groups);
  reject control/bidi/over-length in rendered templated header output (but
  permit , ; = in free-form values such as a bearer token). The bearer path
  already sanitized; the cookie path did not (ranks 33, 34).
- main.go / metadata_cache.go: pin the discovered issuer to the configured
  provider host (sameHost) and refuse/never-cache a mismatch, so a poisoned
  discovery document cannot redefine the JWT trust anchor (ranks 21, 22).
- token_introspection.go: when a distinct API audience is configured, fail
  closed on a missing or mismatched introspection audience; aud parsed as
  string-or-array per RFC 7662 (rank 19).
- logout.go: front-channel logout requires a matching issuer; an empty iss is
  rejected (blocks unauthenticated forced-logout via a known sid) (rank 30).
- token_validation_rs.go: an opaque access token with no ID token and no
  successful introspection fails closed (re-auth) instead of authenticating
  (ranks 18, 20).
- tests: realistic same-host provider mocks; regression tests for the header
  sanitization distinction and the fail-closed paths.

* chore(security): remove unwired dead code with latent footguns

Batch 10 of security audit remediation — delete confirmed-dead, unwired
subsystems (ranks 26, 35, 50). None had a production caller (grep-verified);
removal eliminates the latent footguns and ~2.1k lines of dead code.

- token_validator.go (deleted): an unused *TokenValidator whose validateJWT set
  Valid=true with NO signature verification — a severe footgun if ever wired
  (rank 50). The wired RS-aware validators are unaffected.
- security_monitoring.go (deleted): an unused *SecurityMonitor / ExtractClientIP
  that trusted spoofable X-Forwarded-For / X-Real-IP. The live bearer throttle
  uses clientIPForBearer (RemoteAddr-only), unchanged (rank 35).
- dynamic_client_registration.go: removed the RFC 7592 management methods
  (Update/Read/DeleteClientRegistration) that dereferenced an attacker-
  influenced RegistrationClientURI with the registration token attached and no
  HTTPS/SSRF gate, and had no callers. The wired RFC 7591 RegisterClient and
  credential-store helpers are kept (rank 26).
- tests: removed the tests covering the deleted code.

* chore: add Makefile with yaegi load validation

No Makefile existed. The new `yaegi-validate` target interprets the plugin
under the yaegi interpreter the same way Traefik loads it, catching yaegi-only
incompatibilities (unsupported stdlib symbols, reflection edge cases) that the
native `go build` / `go test` toolchain does not. Importing the plugin forces
yaegi to interpret every file plus its vendored deps; CreateConfig + New
exercise the instantiation path.

- cmd/yaegicheck/main.go: the load driver, marked //go:build ignore so it is
  excluded from `go build ./...` (avoids VCS-stamping a main binary, which
  fails in git-worktree layouts) yet is run explicitly by yaegi.
- Makefile: build / fmt / vet / lint / test / vendor / yaegi-validate / check
  targets; `make check` runs vet + tests + yaegi-validate.

Verified: `make yaegi-validate` passes on this branch — the HKDF cookie
encryption, net-based endpoint validation, and claim sanitizers all interpret
and instantiate cleanly under yaegi.

* ci: bump workflow Go toolchain to 1.25; pin yaegi-validate to v0.16.1

Traefik v3.7.1 (the deployed version) is built with `go 1.25.0`, so the PR and
release workflows now use Go 1.25.x to match the toolchain Traefik uses.

Important distinction: the CI Go version is the build TOOLCHAIN. The plugin's
actual interpreter-compatibility ceiling is the yaegi version Traefik bundles
(v0.16.1, which declares go 1.21 and ships a ~Go 1.22 stdlib symbol surface),
NOT the CI Go version. That ceiling is enforced by `make yaegi-validate` plus
the go.mod language directive — e.g. it is why HKDF is hand-rolled with
hmac+sha256 rather than Go 1.24's crypto/hkdf, which yaegi v0.16.1 lacks.

Also pin Makefile YAEGI_VERSION to v0.16.1 (what Traefik v3.7.1 vendors) so
yaegi-validate exercises the real deployed interpreter instead of @latest,
which could pass on a newer yaegi that supports symbols the deployed one does
not.

* docs: align README/CONFIGURATION with branch behavior changes

- excludedURLs: documented as segment/extension-boundary matching (was
  "prefix-matched") — "/public" no longer also matches "/publicsecret" (rank 14).
- Front-channel logout now requires a matching `iss`; requests without one are
  rejected with 400 (rank 30).
- Add an "Upgrading from an earlier release" note: session cookies are now
  AES-256 encrypted with lifetime tracking sessionMaxAge (one-time re-login on
  upgrade), and invalid configuration (rateLimit < 10, key < 32 bytes, missing
  callbackURL, non-HTTPS remote providerURL) now fails closed at startup.

* fix: remove staticcheck-flagged unused functions; wire staticcheck into make check

CI Static Analysis (standalone staticcheck) failed with U1000 "unused":
- dynamic_client_registration.go: deleteCredentialsFromStore — its only caller
  was the RFC 7592 DeleteClientRegistration removed in the dead-code batch.
- token_test.go: createTestJWTSimple — its only callers were the TokenValidator
  tests removed in the same batch.
Both confirmed to have zero remaining callers and removed. build / vet /
go test ./... / staticcheck ./... all green.

The pre-commit hook runs golangci-lint, but CI runs standalone staticcheck
(which flags U1000). Add a `staticcheck` Makefile target and include it in
`make check` so this class of finding is caught locally before push.

* fix(test): stabilize flaky TestWorkerPool_TaskPanic

tasksFailed is incremented in the worker's deferred recover(), which runs after the panicking task's own defer wg.Done(). wg.Wait() could therefore return before the failure was recorded, so reading the counter immediately raced and flaked on slow CI runners. Poll until the failure lands (2s budget) instead. Verified 200x plain + 50x under -race/GOMAXPROCS=1.
2026-05-30 14:10:32 +01:00
lukaszraczylo f75b2f20e0 fix: resolve cache eviction lock-up and migrate telemetry [patch-release]
universal_cache: stop the write-lock convoy / 100%-CPU spin (observed via pprof: one ServeHTTP goroutine holding c.mu.Lock for hours while 119 requests queued). The per-request populate path (updateLocalCache) PushFronted a duplicate LRU node + overwrote items[key] without removing the prior node; once eviction deleted the key, orphan nodes at Back() were never removable and the eviction loop spun forever under the write lock. Replace the entry in place (mirroring setLocal) and harden evictOldest with a forward-progress guard. Adds universal_cache_orphan_test.go.

telemetry: delete the hand-rolled client; call oss-telemetry v0.2.3 (vendored, Yaegi-safe) directly from New(), once per process via sync.Once.

version: add version.go + workflow-prepare.sh so the release semver is stamped into source at build time (the value cannot be resolved at runtime under Yaegi). dev/source builds keep the 0.0.0-dev sentinel and emit no telemetry.
2026-05-30 13:22:03 +01:00
paiking1 cf6ed1da55 feat: feat: add extraAuthParams (extra authorization request parameters) (#139)
Adds optional extraAuthParams map[string]string config.

Extra params are appended to the authorization request but can never
override plugin-managed params (client_id, state, nonce, etc.).
2026-05-27 21:41:09 +01:00
lukaszraczylo f821b8829b fix: remove write-lock convoy in getLocal + fix mutateState CAS bug
UniversalCache.getLocal(): when a cached token expires, the RLock fast
path (line 385-398) previously fell through to c.mu.Lock() (write lock).
Under Yaegi, the write-lock holder takes 10-100ms for LRU manipulation,
and Go's RWMutex writer-priority blocks ALL new RLock callers. A single
expired-token event turned every concurrent request from read-parallel
into write-serialized — the convoy that produced the 737-goroutine
pileup at 0x400275a608 (pprof captured at /tmp/traefik-spike-1779663149).

Fix: return (nil, false) immediately on expiry for Token/JWK/Session
cache types. The periodic cleanup goroutine handles eviction. Write lock
is never taken on the read path for these cache types.

refreshAttemptTracker.mutateState(): the CAS loop used
t.state.CompareAndSwap(t.state.Load(), next) — a second Load that can
see a different value from a concurrent writer, silently overwriting
their update. Fixed to CompareAndSwap(cur, next) using the snapshot we
computed the mutation from.
2026-05-25 00:06:47 +01:00
lukaszraczylo 5f9c574f95 refactor: delete dead non-RS validators; tests use RS variants
After v1.0.20 the non-RS validation chain had no production callers —
middleware.ServeHTTP dispatched exclusively through isUserAuthenticatedRS.
The orphaned functions stayed reachable only from a handful of test
files and risked silent logic drift against their RS counterparts.

Deleted from production code (~440 LOC):
  - auth_flow.go:        isUserAuthenticated
  - token_manager.go:    validateAzureTokens
  - token_manager.go:    validateGoogleTokens
  - token_manager.go:    validateStandardTokens
  - token_manager.go:    validateTokenExpiry
  - removed now-unused encoding/base64 and encoding/json imports
    from token_manager.go (only the deleted validateStandardTokens
    needed them; the RS variant in token_validation_rs.go keeps its
    own imports).

Added (3 LOC):
  - token_validation_rs.go: validateGoogleTokensRS (trivial delegator,
    parity with the deleted non-RS variant so isUserAuthenticatedRS
    can dispatch cleanly).

Tests ported (10 call sites across 3 files):
  - audience_test.go:                ts.tOidc.validateStandardTokens
  - azure_oidc_test.go:              tOidc.validateAzureTokens,
                                     ts.tOidc.validateGoogleTokens,
                                     ts.tOidc.validateAzureTokens,
                                     ts.tOidc.isUserAuthenticated
  - issue134_followup_graph_test.go: oidc.validateAzureTokens (4x)

Each ported site now constructs a *requestState from its existing
*SessionData via (&requestState{}).captureSession(session) and calls
the *RS variant. Same data, different read source.

Net diff: -440 LOC production, ~+25 LOC tests, +3 LOC stub.
Production now has a single source of truth for token validation;
no parallel implementations to keep in sync.

All tests pass with -race; golangci-lint clean.
2026-05-23 13:04:26 +01:00
lukaszraczylo 7c6f09fb20 feat(middleware): RS-aware token validators (kill ~21 RLocks/request)
Adds token_validation_rs.go with requestState-aware variants of the
token validation path:

  isUserAuthenticatedRS(rs) -> dispatches by provider
    validateStandardTokensRS(rs) -> standard path (eliminates 17 RLocks)
    validateAzureTokensRS(rs)    -> Azure path     (eliminates 10 RLocks)
    validateGoogleTokensRS(rs)   -> delegates to standard
  validateTokenExpiryRS(rs, tok) -> shared expiry check (eliminates 4 RLocks)

middleware.ServeHTTP now calls isUserAuthenticatedRS(rs) on the hot
path. The pre-v1.0.20 non-RS variants are kept untouched for tests
and any future caller that doesn't have a captured snapshot.

Why
---
The standard validation path read SessionData via session.GetX() 17
times, with GetRefreshToken alone called 11 times (every "return
'needs refresh'" branch re-reads it). Each call acquires
sd.sessionMutex.RLock(). Under Yaegi each RLock costs ~1-5ms of
interpreter dispatch. The captured snapshot already lives on rs, so
the RS variants substitute direct struct field reads.

Per-request cost on the hot authenticated path
----------------------------------------------
  ServeHTTP enters:
    + 1 RLock to populate rs (was 0)
  Validation path:
    Standard: was 17 RLocks, now 0
    Azure:    was 10 RLocks, now 0
  processAuthorizedRequestRS:
    was 4-6 GetX calls, now 0 (already in v1.0.19)

Net: ~22-27 fewer Yaegi-dispatched RLock acquisitions per authenticated
request on the hot path.

Caveats
-------
* Refresh / expired / callback paths still use the non-RS validators
  because they can mutate session state between validation and use.
* The RS variants are by-design line-for-line equivalents of the
  originals. If logic in the originals changes, the RS variants need
  matching updates. This is acceptable for now; a future refactor
  could collapse them once the non-RS callers are gone.

All tests pass with -race; golangci-lint clean.
2026-05-23 12:38:42 +01:00
lukaszraczylo 68e1c4319c feat(middleware): per-request context object (requestState)
Adds requeststate.go and threads a *requestState through the
ServeHTTP -> processAuthorizedRequestRS -> forwardAuthorized path.
rs is allocated once at the top of ServeHTTP, populates SessionData
field snapshots under a SINGLE sd.sessionMutex.RLock, and caches the
MetadataSnapshot. Downstream handlers read the cached fields instead
of calling session.GetX() / t.metadataSnap() repeatedly.

Why
---
Under Yaegi each method dispatch (including RWMutex.RLock) costs
~1-5ms of interpreter overhead. SessionData getters each take an
RLock on sd.sessionMutex; the previous hot path called 5-7 of them
per request (GetAuthenticated, GetAccessToken, GetIDToken,
GetRefreshToken, GetUserIdentifier, plus the same set again inside
processAuthorizedRequest). With one batched RLock + cached fields,
that drops to a single RLock for the whole handler chain.

This is scoped — not a wholesale architectural refactor:

* requestState is per-request (alloc at ServeHTTP entry, dropped on
  return). It is NOT a shared cache and never escapes the request.
* The original processAuthorizedRequest is kept unchanged for any
  callers we don't migrate this round (bearer path, callback
  handlers, expired-token handlers). New code path is the RS-aware
  processAuthorizedRequestRS, which middleware.ServeHTTP now uses for
  the happy authenticated-and-not-needing-refresh case.
* Cross-request caches (tokenCache, JWKCache, sessionEntries,
  sessionInvalidationCache) are unchanged. rs is additive, not a
  replacement.

What this does NOT change
-------------------------
* The refresh path still calls session.GetX() in middleware.go
  (handleExpiredToken, refreshToken, defaultInitiateAuthentication)
  because those flows can mutate session state and a stale rs would
  be wrong.
* validateStandardTokens still has its own session.GetX() calls.
  Deep plumbing into the token-verification path is a follow-up.
* No semantic changes to authentication, refresh, or session
  lifecycle — only the read path is optimised.

All tests pass with -race; golangci-lint clean.
2026-05-23 12:22:51 +01:00
lukaszraczylo 17e3f8ef62 fix: snapshot patterns for refresh-tracker and metadata URLs
Two related lock-free snapshot refactors addressing the remaining
post-v1.0.16 code-review findings.

1. refreshAttemptTracker: per-field atomic.Load/Store -> atomic.Value
   snapshot of *attemptState (refresh_coordinator.go).

   Previously each tracker held five independently-atomic fields. The
   cooldown-exit reset wrote cooldownEndNano = 0 first, then separately
   stored attempts = 1 and windowStartNano = now. A concurrent
   isInCooldown call could observe cooldownEndNano = 0 (reset just
   completed) with attempts still at MaxRefreshAttempts, immediately
   triggering a fresh cooldown — a benign double-trigger race that
   nonetheless meant the state machine had observable intermediate
   states.

   New design: state is a *attemptState (immutable) published via
   atomic.Value. All transitions (record/success/failure/window-reset/
   cooldown-enter/cooldown-exit) go through mutateState, which runs a
   CAS loop: load current snapshot -> construct fresh snapshot ->
   CompareAndSwap. Either the entire new state publishes or none of
   it does — no intermediate visibility, no cross-field race.

   Under Yaegi this collapses 3-5 per-field atomic dispatches into one
   atomic.Value.Load on the read path. Write paths pay an extra
   allocation for the new snapshot but avoid the cross-field hazard.

2. MetadataSnapshot: hot-path readers use atomic.Value instead of
   metadataMu.RLock (middleware.go, types.go, main.go, utilities.go).

   middleware.ServeHTTP previously took metadataMu.RLock on every
   non-bypass request to read the single field issuerURL. Under Yaegi
   each RLock acquisition costs 1-5ms of interpreter dispatch.
   updateMetadataEndpoints now also publishes an immutable
   *MetadataSnapshot via atomic.Value; the hot-path reader loads it
   in one op via t.metadataSnap(). Falls back to the legacy
   metadataMu.RLock pattern when the snapshot is unpublished (some
   test setups initialize the struct fields directly without going
   through updateMetadataEndpoints).

   Less-frequent callers (helpers, logout, token_introspection) still
   take metadataMu.RLock and are unchanged. The snapshot strictly
   subsets the metadataMu-protected fields, so those readers see
   identical data.

Note on atomic.Pointer[T]: this would have been the cleaner type but
yaegi v0.16.1's stdlib (used by traefik:v3.7.1) exposes only the
legacy unsafe.Pointer-based atomic primitives — no generic Pointer[T].
atomic.Value provides the same semantics via interface{} + type assert.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:31:51 +01:00
lukaszraczylo 827926bc3a fix(refresh-coordinator): trim per-request mutex/map ops
Three related changes addressing post-v1.0.15 code-review findings and
the user's observation that we have been "throwing maps around" — under
Yaegi every sync.Map / atomic / mutex dispatch costs ~1-5ms of
interpreter overhead, so the number of dispatches per request matters
as much as whether they are lock-free.

1. Remove cleanupTimers map + cleanupTimerMu sync.Mutex.

   scheduleDelayedCleanup previously tracked every pending timer in a
   map guarded by a mutex so a duplicate scheduling could cancel the
   prior timer. That "shouldn't happen" path was the only consumer of
   the map, but the mutex fired on every successful refresh
   completion — another per-request Yaegi-dispatched lock.
   performCleanup is already idempotent (LoadAndDelete on the sync.Map),
   so a duplicate firing is at worst a no-op second call. Dropped the
   map entirely; time.AfterFunc callback now calls performCleanup
   directly.

   Net: -1 sync.Mutex, -1 map field, -2 Lock/Unlock pairs per refresh
   completion. Shutdown simplified — no need to enumerate-and-stop
   timers since the callbacks no longer need teardown.

2. Reorder applyLeaderGates: cooldown check BEFORE recordRefreshAttempt.

   Previously incremented the attempt counter and then checked cooldown.
   Under burst load (many concurrent leaders with different token hashes
   but the same session) every goroutine could increment past
   MaxRefreshAttempts before any one of them observed the threshold,
   so the gate fired too late — same thundering-herd shape that drove
   v1.0.14 into the ground. Reordering makes the gate authoritative:
   only attempts that pass the gate are recorded.

   Semantic change: with MaxRefreshAttempts=N, exactly N attempts now
   run to completion before the (N+1)th is denied. Previously the Nth
   was denied as it tried to record (off-by-one stricter). Test
   assertion updated to N (was N-1).

3. Fix getOrCreateOperation MaxConcurrentRefreshes overshoot.

   The previous CAS-loop allowed a transient overshoot of up to N-1
   leaders when several goroutines all observed `current < max` in the
   same scheduling slice before any one of them succeeded their CAS —
   visible to readers as currentInFlightRefreshes > MaxConcurrentRefreshes
   for a brief window.

   Replaced with the ticket-and-return pattern: increment optimistically,
   decrement if we overshot. Strictly bounded: only the goroutine that
   produces max+1 sees max+1 as committed; the rest decrement back
   immediately. No CAS retry loop needed.

What was NOT done in this commit, and why:

* metadataMu.RLock cached via atomic snapshot — code-reviewer flagged
  this at severity 7 (3 RLocks per request: middleware.go:213,
  token_manager.go:349, token_manager.go:408). The clean fix is an
  atomic.Pointer[*MetadataSnapshot], but generic atomic.Pointer[T] is
  NOT exposed by yaegi v0.16.1's stdlib (only legacy unsafe.Pointer
  primitives). atomic.Value would work but requires a snapshot-struct
  refactor across ~15 call sites (helpers/logout/token_introspection/
  token_manager/main/middleware). Deferred to a focused future PR.

* isInCooldown multi-field reset race — the cooldown-reset CAS wins
  on cooldownEndNano, then separately stores attempts/consecutiveFailures/
  windowStartNano. A concurrent isInCooldown can briefly see the
  pre-reset attempts value and trigger a fresh cooldown. Semantic glitch
  (double-cooldown), not a correctness disaster. Fix is a single atomic
  pointer swap of an immutable snapshot — same atomic.Pointer constraint
  as above. Deferred.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:23:16 +01:00
lukaszraczylo abbfdb02a7 fix(jwk): replace JWKCache.mutex with singleflight pattern
JWKCache.GetJWKS previously held a sync.RWMutex.Lock() across the entire
HTTP round-trip to the IdP's JWKS endpoint (jwk.go:93). On a cold cache
(cold pod, JWK rotation, transient network blip) every concurrent
request piled up on this single global write-lock. Under Yaegi each
Lock() acquisition costs 10-50ms of interpreter dispatch — same
architectural shape as the bugs v1.0.14 and v1.0.15 already fixed,
just one that hadn't surfaced as the dominant bottleneck yet.

Code-review post-spike #2 flagged this at confidence 9/10 as the next
likely death-spiral on pod cold-start.

Change replaces the lock with a sync.Map-based singleflight: the first
caller for a given JWKS URL performs the fetch; concurrent callers
attach to the same *jwksFetch and wait on its done channel for the
result. Cold-cache cost is now O(1) HTTP fetch regardless of how many
goroutines are waiting, and no Yaegi-dispatched lock is held during the
fetch itself.

Correctness:
- LoadOrStore winner does the work; losers wait on a done channel.
- Done channel close is in a defer, so panics in fetchJWKS still
  unblock waiters.
- Map entry is removed in the same defer, so a fresh failed fetch can
  be retried by the next request without waiting for any stale entry.
- ctx.Done() unblocks waiters independently of the leader's progress.
- Re-checks the cache after winning LoadOrStore, since another fetch
  may have populated the cache between the initial miss and the win.

Cleanup: also removes a stray yaegi-extract output file
(github_com-lukaszraczylo-traefikoidc.go) that was accidentally
committed during local yaegi compatibility testing.

All tests pass with -race; golangci-lint clean.
2026-05-23 11:05:24 +01:00
lukaszraczylo 72e2b682bb fix: eliminate per-request global mutexes in Yaegi hot paths
The v1.0.14 fix replaced one contended sync.RWMutex (RefreshCoordinator.
refreshMutex) with sync.Map. Production showed the same death-spiral
signature recurring ~2 hours later — same shape, different mutex:
65 goroutines stuck on a sync.(*RWMutex).Lock at one address, pod
pinned at 1000m CPU, identical Yaegi runCfg/reflect.Value.Call stack
pattern. The mutex was RefreshCoordinator.attemptsMutex.

Generalising: under Yaegi (interpreted Go for traefik plugins), any
per-request global mutex acquisition is a latent serialization point.
reflect.Value.Call dispatch on a held lock turns a microsecond
critical section into a multi-millisecond one, and on a GOMAXPROCS=1
pod the queue is unbounded.

This commit removes every per-request global mutex on the hot path:

1. RefreshCoordinator.attemptsMutex (sync.RWMutex)
   sessionRefreshAttempts: map -> sync.Map.
   refreshAttemptTracker: all fields atomic (int32, int64 UnixNano,
   cooldownEndNano == 0 as the not-in-cooldown sentinel, replacing
   the inCooldown bool).
   isInCooldown / recordRefreshAttempt / recordRefreshSuccess /
   recordRefreshFailure all become lock-free. Cooldown entry uses
   CompareAndSwapInt64 so only one goroutine logs the transition.

2. RefreshCircuitBreaker.mutex (sync.RWMutex)
   lastFailureTime / lastSuccessTime -> atomic.Int64 UnixNano.
   state and failures already atomic.
   AllowRequest / RecordSuccess / RecordFailure now pure atomic ops.

3. TraefikOidc.firstRequestMutex (sync.Mutex)
   firstRequestReceived bool -> firstRequestStarted int32.
   metadataRefreshStarted bool -> metadataRefreshStartedAtomic int32.
   ServeHTTP bootstrap path uses CompareAndSwapInt32 — fires once,
   zero steady-state cost. Previously the mutex was acquired on
   every non-health request forever.

4. TraefikOidc.metadataRetryMutex (sync.Mutex)
   lastMetadataRetryTime time.Time -> lastMetadataRetryNano int64.
   The 30-second retry throttle is now a CAS on lastMetadataRetryNano.

cleanupStaleEntries iterates via sync.Map.Range; eviction is a
CompareAndDelete by pointer identity so a tracker freshly re-used by
a concurrent caller is not lost.

Empirical evidence (3 specialist-agent analysis of the v1.0.14 spike,
profiles in /tmp/traefik-spike-1779511683/):
  * mutex profile: 97% delay in sync.(*Mutex).Unlock via
    HTTPHandlerSwitcher -> accesslog -> metrics -> backoff.RetryNotify
  * 65 stuck goroutines at one RWMutex address (0x40022eb648),
    identical Yaegi CFG pointer, all on rc.attemptsMutex via
    recordRefreshAttempt + isInCooldown
  * traffic driver: long-lived in-cluster Go-http-client doing
    ~5.4 req/s POST embeddings via OIDC cookie session → same
    sessionID → contention all funnels to one tracker entry

Yaegi support for sync/atomic confirmed at
github.com/traefik/yaegi@v0.16.1/stdlib/go1_22_sync_atomic.go:
AddInt32/Int64, LoadInt32/Int64, StoreInt32/Int64,
CompareAndSwapInt32/Int64 all exposed via reflect.ValueOf. Yaegi
dispatches each call through reflect.Value.Call to the COMPILED
atomic.* function, which executes a single hardware CAS/LOCK-XADD
instruction. Each atomic op still pays Yaegi dispatch cost but
cannot block — no queueing, no death spiral.

Trade-off acknowledged: v1.0.15 issues ~6-8 atomic/sync.Map ops per
leader-path request vs the 4 mutex ops of v1.0.14. Under low
contention this is a modest CPU bump. Under high contention it's
an unbounded → bounded transformation. Net win.

All tests pass with -race; golangci-lint clean.
2026-05-23 10:47:21 +01:00
lukaszraczylo ae4ccaa89d fix(refresh-coordinator): replace global RWMutex with sync.Map
Under Yaegi, the RefreshCoordinator.refreshMutex was held for tens of
milliseconds per request because every operation inside the critical
section (map access, isInCooldown, recordRefreshAttempt,
isUnderMemoryPressure, atomic ops, struct allocation) is dispatched
through reflect.Value.Call with full arg boxing/unboxing.

Concurrent refreshes on the same coordinator serialized into a queue
that grew without bound. Live capture in production (3 Grafana
dashboards left open) showed:
  * 63 goroutines stuck on rc.refreshMutex.Lock() for 1-11 minutes
  * pod pinned at 1000m CPU (GOMAXPROCS=1)
  * 5.15M allocs/sec, 0.45 RPS effective throughput
  * yaegi.call.func9 accounting for 92.66% of cumulative allocs
  * mutex profile dominated by sync.(*Mutex).Unlock via the request chain

Change inFlightRefreshes from map[string]*refreshOperation+RWMutex to
sync.Map and rewrite getOrCreateOperation to:
  1. Speculatively allocate the candidate operation.
  2. Atomically LoadOrStore by tokenHash. Joiners take the existing
     operation; leader takes the new one. No global lock acquired.
  3. Leader runs rate-limit / cooldown / memory-pressure gates AFTER
     the atomic store. Joiners share the leader's outcome via op.done.
  4. Reserve the concurrent-refresh slot via CompareAndSwap so the
     count cannot overshoot in absence of the old serializing lock.
  5. On any gate failure the leader calls failCandidate, which deletes
     the entry from sync.Map, records the error on op.result and closes
     op.done so any joiner that snuck in returns the same error.

performCleanup becomes a single sync.Map.LoadAndDelete, eliminating
the lock entirely on the cleanup path.

Net effect: critical section is no longer Yaegi-interpreted; it
collapses to atomic instructions on a sharded sync.Map. Refresh
contention disappears even under Yaegi.

All tests pass with -race; golangci-lint clean.
2026-05-23 02:34:49 +01:00
lukaszraczylo 984fd1c08f docs: add Telemetry section linking to oss-telemetry opt-out docs
Discloses the single anonymous adoption ping sent on first plugin
instantiation. Points users to the upstream README section for the
disclosure pattern and to the local telemetry.go for the inline
implementation.
2026-05-21 04:07:19 +01:00
lukaszraczylo 99bdd23986 feat: anonymous usage telemetry via inline oss-telemetry
Adds a yaegi-safe inline telemetry helper that fires a single
fire-and-forget ping at plugin load. Helps track adoption and version
spread. No persistent identifiers are collected.

Implementation notes:
- inline (no external dep) so Traefik plugin loader does not need to
  resolve a new vendored module
- stdlib-only, no generics, no range-over-int — verified to load under
  yaegi 0.16.x (full plugin import + CreateConfig/New symbol lookup OK)
- avoids `switch{case A,B,C:}` blocks where some yaegi releases
  mis-evaluate comma-separated case lists
- sync.Once guards against amplified pings on Traefik dynamic config
  reloads (which re-instantiate the middleware)

Opt out via any of:
  DO_NOT_TRACK=1
  OSS_TELEMETRY_DISABLED=1
  TRAEFIKOIDC_DISABLE_TELEMETRY=1
2026-05-21 03:20:36 +01:00
lukaszraczylo a548665edb feat: opt-in M2M bearer-token authentication (supersedes #93) (#140)
* docs: bearer-token auth design spec

* docs: harden bearer-auth spec with security review findings

* feat(bearer): opt-in M2M bearer-token authentication

Adds an opt-in Authorization: Bearer <jwt> path for machine-to-machine
clients. Replaces and supersedes the broken approach in PR #93
(synthetic-session that omitted user_identifier and skipped ID-token
rejection / replay-protection-semantics / kid-pinning / etc.).

Design

  Two auth entrypoints feed one shared post-auth pipeline:

    cookie path  ─┐
                  ├── forwardAuthorized(rw, req, *principal)
    bearer path  ─┘    (roles/groups, header injection, security
                        headers, cookie strip, forward)

  buildPrincipalFromSession and buildPrincipalFromBearerToken produce
  the same `principal` value type. forwardAuthorized is session-agnostic
  and runs the existing post-auth work; processAuthorizedRequest now
  wraps it with the session-specific concerns (backchannel-logout,
  dirty/Save). The cookie path's behaviour is byte-identical to before
  this PR; the existing test suite passes unmodified.

Security hardening baked into the bearer path

  - Audience MANDATORY. Startup fails when EnableBearerAuth=true and
    Audience is empty.
  - BearerIdentifierClaim defaults to "sub"; "email" is rejected at
    startup to avoid the unverified-email spoofing footgun. Cookie
    path's UserIdentifierClaim is unaffected and still defaults to
    "email".
  - ID tokens explicitly rejected via the existing detectTokenType
    helper (nonce, typ=at+jwt, token_use, scope, aud-vs-clientID
    heuristics); belt-and-braces nonce/token_use=id rejection on top.
  - alg pinned to asymmetric allowlist (RS/PS/ES 256/384/512) BEFORE
    JWKS fetch, blocking alg=none and alg=HS* probes from amplifying
    into upstream calls.
  - kid length capped at 256 bytes and charset-restricted before JWKS
    fetch, blocking pathological-kid JWKS amplification.
  - Multi-audience tokens require azp == clientID.
  - iat upper-age bound (MaxTokenAgeSeconds, default 24h) bounds clock-
    manipulation and forever-token abuse.
  - Identifier sanitization: length cap, control-char + bidi-override
    + delimiter (, ; =) rejection.
  - Per-IP failure throttle: configurable threshold/window/penalty;
    returns 429 + Retry-After. Limits offline-guessing-style attacks
    and protects the shared rate-limiter / JWKS endpoint.
  - JTI replay marking suppressed via new internal verifyOpts
    {skipReplayMarking} so the same bearer can be reused until exp;
    the blacklist Get stays active so RevokeToken still terminates a
    bearer token immediately. The existing exported VerifyToken
    interface is unchanged so all mocks continue to work.
  - Cookie wins by default when both bearer and cookie are present
    (safer against browser/extension/proxy bearer injection).
    Operator can flip via BearerOverridesCookie.
  - Authorization header stripped on forward by default; also stripped
    on excluded URLs so the token can't leak into health/metrics
    downstream logs.
  - Optional RFC 7662 introspection via existing
    requireTokenIntrospection. Introspection-endpoint failure returns
    503 (distinguishes infra from token rejection).
  - 401s use RFC 6750 WWW-Authenticate hints (toggleable). Failure
    reason is logged at debug; raw tokens are never logged.

Implementation

  - principal.go: pure-data principal type and buildPrincipalFromSession.
  - bearer_auth.go: alg/kid pin, classifier, identifier sanitization,
    multi-aud azp gate, iat age check, per-IP failure tracker,
    handleBearerRequest, buildPrincipalFromBearerToken.
  - token_manager.go: VerifyToken now wraps a new verifyTokenWithOpts
    that accepts internal-only verifyOpts. Existing callers, the
    TokenVerifier interface, and all mocks unchanged.
  - middleware.go: extracted forwardAuthorized from
    processAuthorizedRequest; wired bearer detection after init wait
    + after bypass; excluded-URL Authorization strip when bearer
    enabled.
  - settings.go: ten new config fields with defaults applied in
    CreateConfig.
  - main.go: startup validation for audience + identifier-claim
    guard; bearer failure tracker init.

Tests

  - bearer_auth_test.go: table-driven helper tests for every new
    component (parseBearerJOSEHeader, sanitizeBearerIdentifier,
    resolveBearerIdentifier, enforceMultiAudienceAzp, enforceIatAge,
    bearerFailureTracker, detectBearerToken). Integration tests
    through ServeHTTP covering happy path, ID-token rejection,
    alg=none rejection, oversized kid, multi-aud with/without azp,
    iat-too-old, bidi identifier, replay (100x reuse), 429 throttle
    trip, excluded-URL strip, roles gate, cookie-wins precedence,
    BearerOverridesCookie, oversized token, malformed JWT,
    feature-off pass-through. Startup validation for audience-
    required and email-identifier-rejected.
  - All existing tests pass unmodified (cookie-path regression).
  - go vet clean. golangci-lint clean (0 issues). Race detector
    clean on bearer tests.

Documentation

  - README.md: bearer auth section with security highlights and
    config snippet; doc link in the index.
  - .traefik.yml: commented config block exposing every bearer knob.
  - docs/CONFIGURATION.md: new subsection with full parameter table.
  - docs/BEARER_AUTH.md: threat model, hardening matrix, failure
    response table, operational guidance, known follow-ups.
  - docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md:
    design spec + security-review hardening history.

* fix(cache): redact raw cache keys in debug logs (CodeQL go/clear-text-logging)

CodeQL flagged 9 high-severity alerts (go/clear-text-logging) where the
in-memory cache and the hybrid L1+L2 backend printed `key=%s` at debug.
Cache callers (token cache, blacklist, introspection cache) pass raw
access / refresh / id tokens as cache keys, so any debug-enabled
deployment would write them to log streams.

Pre-existing issue. CodeQL started flagging it on this PR because the
new bearer-auth path adds a data-flow source (req.Header.Get("Authorization"))
that reaches the existing logging sinks via the same cache. The cookie
path had the same risk but wasn't tracked as taint by CodeQL.

Fix: hash the key (SHA-256[:8] hex) before printing. Same approach the
bearer-auth logger uses for principal identifiers (spec §13). Doesn't
change cache semantics — same key still produces the same hash, so
debug correlation across log lines is preserved without exposing the
raw value.

Touches both affected packages:
  - internal/cache/cache.go (2 sites: Set + LRU eviction)
  - internal/cache/backends/hybrid.go (12 sites: L1/L2 read/write/fallback)

New helper `redactKey` colocated with each package (unexported,
package-local) keeps the change blast radius narrow. Tests green; lint
clean.

* docs(bearer): how to obtain bearer tokens from the OIDC provider

Adds a section walking operators through the OAuth 2.0 client_credentials
flow (RFC 6749 §4.4) and the JWT bearer assertion alternative (RFC 7523),
with a worked Auth0-shape curl example, a per-provider quick reference
(Auth0, Okta, Keycloak, Entra v2, Cognito, GitLab, Google), operational
notes (token TTL, caching, JWKS rotation, revocation, scope vs audience,
secret hygiene), and a three-line validation loop.

Most common operator confusion: "I enabled the feature but tokens get
401'd" — almost always missing or wrong audience. The new section makes
the audience-matching requirement loud, with per-provider parameter
names so people don't have to dig through IdP docs.

Locations:
  - docs/BEARER_AUTH.md  — full section under "Quick start"
  - README.md            — short snippet + deep link
2026-05-18 17:35:37 +01:00
lukaszraczylo 8c5df82dcf fix(azure): treat Microsoft proprietary access tokens as opaque (#134) (#138)
Followup to issue #134 — two reporters returned saying that even with the
JWKS caching fix in v1.0.7/v1.0.8, every request emitted:

  ERROR: TraefikOidcPlugin: UNKNOWN token verification failed:
    signature verification failed: crypto/rsa: verification error
  ERROR: TraefikOidcPlugin: DIAGNOSTIC: Signature verification failed for
    kid=<kid>, alg=RS256: crypto/rsa: verification error

Root cause: when an Azure tenant is configured without a custom API
resource, Microsoft issues access tokens for Microsoft Graph (or Azure
Mgmt). These tokens carry a `nonce` value in the JWT *header*; the bytes
that get signed contain SHA256(nonce), while the wire token ships the
original nonce. Any standard JWS verifier rejects the signature, which is
exactly Microsoft's intent — they document the format as proprietary and
tell client apps not to validate it
(https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
"you can't validate tokens for Microsoft Graph according to these rules
due to their proprietary format").

validateAzureTokens was nonetheless attempting JWT verification on every
JWT-shaped access token, then silently falling back to the ID token when
verification failed. Auth still worked end-to-end, but every request
spammed two error log lines.

Two-layer defense:

* validateAzureTokens now detects the proprietary-nonce header before
  calling verifyToken on the access token. When detected, the token is
  treated as opaque (matching the existing branch for non-JWT tokens) and
  validation proceeds via the ID token, exactly as Microsoft prescribes.

* VerifyJWTSignatureAndClaims downgrades the DIAGNOSTIC error log to
  debug for tokens carrying the same proprietary marker, in case any
  path outside validateAzureTokens reaches it.

Authorization still hinges on a separately-verifiable ID token — the
confused-deputy guard from CWE-441 is preserved (and explicitly tested).
2026-05-11 17:31:37 +01:00
lukaszraczylo aa96e9dbee Add sponsorship
Just in case you appreciate this project, feel generous and want to sponsor my caffeine addiction.
2026-05-10 21:25:26 +01:00
lukaszraczylo 1e33bb0a4d feat(auth): support private_key_jwt and client_secret_basic (#137)
revocation endpoints, joining the existing client_secret_post default.
Both are opt-in via the new clientAuthMethod config field. Closes #135.

private_key_jwt (RFC 7523 §2.2 / OpenID Connect Core §9)
========================================================
Plugin signs a short-lived JWT with a configured private key and presents
it as client_assertion. Use when the IdP enforces short secret TTLs or
requires secretless client auth (Microsoft Entra ID / Azure AD, Okta,
Auth0, Keycloak).

New Config fields:
  clientAuthMethod          (default: client_secret_post)
  clientAssertionPrivateKey (inline PEM)
  clientAssertionKeyPath    (PEM file path; mutually exclusive)
  clientAssertionKeyID      (JWS kid header — required)
  clientAssertionAlg        (default: RS256; RS/PS/ES 256–512 supported)

PEM forms accepted: PKCS#8, PKCS#1, SEC1.
Assertion claims: iss=sub=clientID, aud=tokenURL, iat=now, exp=now+60s,
random 16-byte hex jti per request. ECDSA signatures are raw r||s per
RFC 7515 (not ASN.1).

client_secret_basic (RFC 6749 §2.3.1)
=====================================
Sends credentials in the Authorization: Basic header instead of the
body. Both halves are form-urlencoded individually before base64 — that
encoding step is required by the spec and is NOT what stdlib's
http.Request.SetBasicAuth does, so the plugin uses its own helper. The
form body omits client_id and client_secret on this path.

Wire-up
=======
Both methods are dispatched at the same two call sites:
  helpers.go:exchangeTokens — auth_code + refresh_token grants
  token_manager.go:RevokeTokenWithProvider — RFC 7009 revocation

Existing clientSecret deployments are unaffected — empty
clientAuthMethod maps to the historical client_secret_post behavior, and
clientAssertion remains nil unless the new fields are set.

Yaegi compatibility
===================
All required crypto/rsa, crypto/ecdsa, crypto/x509, encoding/pem and
crypto/sha256/384/512 symbols are exposed by the traefik/yaegi stdlib
symbol tables (RSA SignPKCS1v15 + SignPSS, ECDSA Sign,
ParsePKCS8/1PrivateKey, ParseECPrivateKey).

Tests (16 new)
==============
Algorithm-family coverage:
  TestIssue135_SignerRSAFamily — RS256/384/512 + PS256/384/512
  TestIssue135_SignerECDSAFamily — ES256/384/512, raw r||s shape
  TestIssue135_SignerRejectsAlgKeyMismatch
  TestIssue135_SignerJTIUniqueness — 50 sigs, all jti distinct
  TestIssue135_SignerPEMVariants — PKCS#8, PKCS#1, SEC1

Config validation:
  TestIssue135_ConfigValidation — full Validate() matrix
  TestIssue135_ConfigKeyPathLoadsFile

Wire-up:
  TestIssue135_AuthCodeExchangeUsesAssertion
  TestIssue135_RefreshTokenUsesAssertion
  TestIssue135_BackcompatClientSecretPath
  TestIssue135_RevocationUsesAssertion
  TestIssue135_BuildSignerFromInlineConfig
  TestIssue135_BuildSignerDefaultsToRS256
  TestIssue135_ClientSecretBasicAuth — Authorization header, no body creds
  TestIssue135_ClientSecretBasicURLEncodesReservedChars — :, +, /, @, =, &
  TestIssue135_ClientSecretBasicRevocation — revocation parity

Documentation
=============
  README.md — required-row note + 5 optional rows + dedicated section
  docs/CONFIGURATION.md — new Client Authentication section with three
    method subsections, OpenSSL keygen snippet, RFC links
  docs/index.html — 5 new config-table rows + Private Key JWT
    explainer card
  .traefik.yml + examples/complete-traefik-config.yaml — commented
    opt-in example

Out of scope (deferred)
=======================
mTLS / tls_client_auth (RFC 8705) — separate change; requires per-call
http.Client with tls.Config.Certificates and conflicts with the current
pooled HTTP client architecture.
2026-05-09 18:02:41 +01:00
lukaszraczylo bfd702a447 fix(jwk): keep parsed JWKS in local cache only (#134) (#136)
Under yaegi (Traefik's plugin runtime) json.Marshal exposes unexported
struct fields with an X-prefixed name. parsedJWKS{ keys map[string]
crypto.PublicKey } therefore round-tripped through Redis as
{"Xkeys":{"<kid>":{"N":<huge>,"E":65537}}} — *rsa.PublicKey.N is a
*big.Int that marshals to a JSON number hundreds of digits long. On
read, json.Unmarshal into interface{} parses numbers as float64, which
cannot represent that range:

  Failed to deserialize value for key .../discovery/v2.0/keys:parsed:
  json: cannot unmarshal number 2251513...
    into Go value of type float64

Auth still worked (the JWKCache rebuilt the keys in memory on every
miss) but the error log spammed every request.

Two structural problems were behind it:

* parsedJWKS holds crypto.PublicKey interface values that aren't
  meaningfully JSON-serializable. Even on compiled Go (where the
  unexported field marshals to {}), the post-roundtrip type assertion
  v.(*parsedJWKS) silently failed and the cache was useless.
* The same pattern applied to *JWKSet — the struct shape survived JSON
  but the type assertion still failed, defeating the cache for every
  call that went through Redis.

Both keys now use the new UniversalCache.SetLocal/GetLocal pair, which
skips the configured distributed backend entirely. JWK rotation is rare
and a per-replica HTTP fetch on cold cache is cheap, so cross-replica
coherence buys nothing for these entries.

Stale Redis entries written by previous versions are simply ignored —
the new code never reads under those keys, and Redis TTL retires them.

Includes regression coverage for the Azure round-trip, the
poisoned-stale-data scenario, and the SetLocal/GetLocal isolation
contract.

patch-release
2026-05-08 13:35:23 +01:00
lukaszraczylo 68c150eba4 fix(cache/redis): honor enableTLS for Redis backend (#133)
The redis.enableTLS / redis.tlsSkipVerify settings were accepted by the
config layer but silently dropped before reaching the connection pool, so
the plugin always dialed Redis in plaintext. This blocked TLS-only Redis
deployments such as AWS ElastiCache with in-transit encryption.

- Add EnableTLS, TLSSkipVerify, TLSServerName to backends.Config and
  PoolConfig and forward them through universal_cache_singleton ->
  backends.Config -> PoolConfig.
- In the connection pool, dial via tls.Dialer.DialContext (TLS 1.2
  minimum) with SNI defaulting to the host part of the configured
  Address when TLSServerName is empty, so ElastiCache cluster endpoints
  validate out of the box. Plain dial path now also propagates ctx.
- Add regression tests covering successful TLS negotiation with skip-
  verify, rejection of self-signed certs without skip-verify, rejection
  of plain TCP servers when EnableTLS=true, and unaffected plaintext
  behavior.
- Document maxRefreshTokenAgeSeconds (added in 1b6c861) and the implicit
  SSE / WebSocket auth bypass (added in 684a990) in README.md,
  docs/CONFIGURATION.md and docs/index.html.
- Add the missing redis.tlsSkipVerify row to docs/index.html and clarify
  the redis.enableTLS description.

patch-release
2026-05-07 12:24:13 +01:00
lukaszraczylo 9cbca4c4fb fix(refresh): honor userIdentifierClaim in token refresh path (#132)
patch-release

The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.

The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.

Mirror the callback logic in refreshToken so both paths behave the same.

Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").

Compat note: existing user sessions invalidate on upgrade — every active
user re-authenticates once after deploying this change.
2026-05-07 09:21:41 +01:00
lukaszraczylo 684a990f59 fix: reduce yaegi CPU footprint + require auth on SSE/WebSocket bypass
minor-release

Behaviour changes (potentially breaking for operators relying on the prior
unauthenticated SSE bypass):

* SSE (`Accept: text/event-stream`) and WebSocket upgrade requests now
  return 401 when no authenticated session is present. Previously the
  bypass forwarded unconditionally, which let any caller reach the
  backend by setting the right header. Excluded URLs are unchanged.
  Operators relying on unauthenticated SSE/WS access must move the path
  into ExcludedURLs.

Performance fixes (target: long-running dashboards like Grafana / ArgoCD
where many panels poll concurrently while the page stays open):

* Stop honouring isTestMode() for the singleton-token-cleanup interval
  under yaegi (the Traefik plugin runtime). In production the plugin was
  running a 20 Hz no-op cleanup ticker because runtime.Compiler ==
  "yaegi" tripped the test-mode branch.
* processAuthorizedRequest now resolves ID-token claims at most once per
  request via SessionData.GetIDTokenClaims (already cached on the
  session) and reuses them for both groups/roles extraction and
  header-template rendering. Previously every authenticated request
  parsed the JWT twice.
* Added extractGroupsAndRolesFromClaims to drive groups/roles off
  pre-parsed claims; extractGroupsAndRoles still works for tests.
* Removed the unconditional session.MarkDirty() in the header-templates
  branch. Templates only mutate request headers, not session state, so
  the prior MarkDirty was re-encrypting and rewriting all session
  cookies on every authenticated request that used header templates.

Other:

* Added isWebSocketUpgrade (RFC 6455 handshake detection — Connection:
  Upgrade + Upgrade: websocket, tolerant of multi-token Connection
  headers and case).
* Renamed applySSEUserHeaders -> applyBypassUserHeaders; it now returns
  bool so the dispatcher can reject unauthenticated SSE/WS with 401.
* Added tests for SSE and WS bypass covering both the auth-rejection
  path and the authenticated forward path.
2026-05-02 03:12:20 +01:00
lukaszraczylo 1b6c8616fd fix(refresh): coalesce refresh-token grants + bound goroutines + cache hot path (target v0.8.27) (#131)
* fix(refresh): wire RefreshCoordinator into the live refresh path

The RefreshCoordinator existed but was never instantiated. The actual
refresh path used only session.refreshMutex, which is per-SessionData
instance - and SessionData is pulled from a sync.Pool per request -
so concurrent requests sharing a refresh token had ZERO coordination.

Symptom: when access_token expired (e.g. 5min Zitadel default), every
in-flight request from a polling client (Grafana panels) entered the
refresh path simultaneously and POSTed the same refresh_token to the
IdP. With refresh-token rotation enabled (Zitadel/Authentik default),
only one grant succeeded; the rest got invalid_grant and each cleared
the entire session. Subsequent requests then thrashed in re-auth loops.

This commit:
- adds refreshCoordinator field on TraefikOidc
- instantiates it in NewWithContext with DefaultRefreshCoordinatorConfig
- shuts it down in Close() under shutdownOnce
- routes refreshToken() through the coordinator via coordinatedTokenRefresh,
  which collapses concurrent grants to a single upstream call per
  refresh_token hash
- exports refreshCoordinatorSessionID for both internal hashing and the
  middleware-level wireup so dedup keys stay aligned

Behavioural notes:
- nil-coordinator fallback preserves existing tests that build TraefikOidc
  literals without going through the constructor
- followers receive the same TokenResponse/error as the leader, so no
  per-instance code paths change
- existing TestGetNewTokenWithRefreshToken_Concurrency still passes
  because it hits GetNewTokenWithRefreshToken directly, below the
  coordinator boundary

Tests:
- refresh_coordinator_wireup_test.go: 50 concurrent refreshes coalesce
  to <=2 upstream calls; distinct tokens still run in parallel; nil
  coordinator falls back cleanly

* perf(cache): bound L1 backfill goroutines in HybridBackend

Get() and GetMany() previously spawned a goroutine per L2 hit to write
the value through to L1. Under sustained polling traffic (e.g. a Grafana
dashboard refreshing every 30s with N panels) this minted thousands of
goroutines, each running in Yaegi - directly contributing to the
~1000% CPU spike that pairs with the refresh-token herd.

Replace the per-hit goroutines with a single l1BackfillWorker fed by
l1BackfillBuffer, mirroring the existing asyncWriteBuffer/asyncWriteWorker
pattern for L2 writes. Buffer overflow drops the backfill (counted via
l1BackfillDrops) - a dropped backfill just means the next L2 hit for
that key re-queues it, which is safe.

Tests:
- TestHybridBackend_L1BackfillBounded: 1000 distinct L2 hits keep
  goroutine count within +20 of baseline (pre-fix it grew by ~1000)
- TestHybridBackend_L1BackfillFullDrops: drops are accounted for when
  the buffer is saturated and the worker is stopped

* feat(refresh): implement isRefreshTokenExpired heuristic

Replace the placeholder `return false` with a real check based on the
issued_at timestamp that SetRefreshToken already stamps into the session.
Gated by a new MaxRefreshTokenAgeSeconds config field (default 21600 =
6h, matching the existing comment). 0 disables the check.

This wires the previously-dead refreshTokenExpired branch in middleware.go,
which short-circuits AJAX requests with a 401 instead of letting them
hammer the IdP for a refresh token that's almost certainly stale - the
classic Grafana-after-long-pause failure mode.

Behaviour:
- maxRefreshTokenAge=0 disables the check (preserves prior behaviour)
- legacy sessions without issued_at still attempt one refresh; the IdP
  remains the source of truth on first try
- nil-receiver and nil-session guards keep test code that builds
  TraefikOidc literals safe

Tests:
- TestIsRefreshTokenExpired_DisabledWhenAgeZero
- TestIsRefreshTokenExpired_LegacySessionWithoutTimestamp
- TestIsRefreshTokenExpired_WithinWindow
- TestIsRefreshTokenExpired_BeyondWindow
- TestIsRefreshTokenExpired_NilGuards

* perf(token): skip parseJWT on cache hit in VerifyToken

The token cache fast-return existed but ran AFTER parseJWT, so every
validation paid for base64 + JSON unmarshal even on a hit. Under bursty
traffic (e.g. 10+ concurrent panel requests on every Grafana dashboard
refresh, each calling validateStandardTokens which verifies BOTH the
access token and the ID token), this is two redundant parses per
request multiplied by the panel count.

Move the cache lookup ahead of parseJWT. On a hit the function returns
nil immediately. On a miss the original flow runs unchanged.

Also nil-guard t.tokenCache to keep partial-literal test instances safe
(matches the same pattern we already use for tokenBlacklist).

Tests:
- TestVerifyToken_CacheHitSkipsParse: cache pre-populated with claims
  for a token whose body would fail parseJWT - returns nil iff the
  fast-path bypasses the parse
- TestVerifyToken_CacheMissStillParses: a syntactically valid but
  unsigned token still errors past parseJWT on cache miss

* feat(refresh): cross-replica refresh-grant dedup via shared cache

The in-process RefreshCoordinator added in 9f96d8c already collapses
concurrent refresh-token grants on a single Traefik replica. With the
plugin's existing Redis (Dragonfly) cache infrastructure available, we
can extend that dedup across replicas: if pod A refreshes a token at
T+0 and pod B receives a request for the same session at T+1, pod B
should reuse pod A's result rather than POSTing the now-rotated refresh
token to the IdP.

Implementation:
- Add a refreshResultCache to UniversalCacheManager (memory-only when
  Redis is disabled, Redis-backed in production via the existing
  hybrid/Redis-only mode selection)
- Expose it through CacheManager.GetSharedRefreshResultCache and on the
  TraefikOidc struct as refreshResultCache (CacheInterface)
- Inside the closure passed to RefreshCoordinator.CoordinateRefresh,
  consult the cache first; on hit return immediately, on miss exchange
  with the IdP and populate the cache for peers
- 5s TTL: long enough for siblings to observe, short enough that a
  rotated refresh token cannot be re-supplied after the IdP has moved on
- Errors are intentionally NOT cached - peers must always be able to
  retry on their own

Pragmatic choice: optimistic cache rather than a hard distributed lock.
- A hard lock (SET NX + poll) doubles Redis RTT and risks dead-locks
  if a Traefik pod dies mid-grant.
- The user's BGP+Local externalTrafficPolicy already pins ingress for
  a session to one node in steady state, so cross-pod racing is rare.
- This optimistic path catches the rare failover case without adding
  failure modes.

Tests:
- TestCoordinatedTokenRefresh_CrossReplicaCacheHit: pre-populated cache
  short-circuits the upstream call entirely (0 IdP calls)
- TestCoordinatedTokenRefresh_PopulatesCrossReplicaCache: leader stores
  a successful result for peers to find
- TestCoordinatedTokenRefresh_ErrorIsNotCached: invalid_grant must not
  poison the dedup cache - peers must retry independently
2026-04-30 18:52:39 +01:00
lukaszraczylo 4d28fa01ab perf(jwk,cache): cache parsed public keys + RLock token cache reads
Hot-path JWT verification rebuilt the public key on every call:
  jwk -> ToRSAPublicKey -> x509.MarshalPKIXPublicKey -> pem.Encode
  -> verifySignature -> pem.Decode -> x509.ParsePKIXPublicKey -> verify
Under yaegi this pinned a CPU when many concurrent dashboard panels
poll behind the middleware. The PEM round trip is pure waste.

* jwk.go: cache pre-parsed crypto.PublicKey per kid alongside the
  raw JWKSet (parallel cache entry, same 1h TTL, invalidates together).
* jwt.go: split verifySignatureWithKey from verifySignature; existing
  PEM-input entry point preserved for backchannel-logout callers.
* token_manager.go: VerifyJWTSignatureAndClaims now goes straight from
  jwks cache to verifySignatureWithKey, no PEM round trip and no
  per-request availableKids slice.
* universal_cache.go: token/JWK/session Get() takes RLock when the
  entry is unexpired, so concurrent token verifications no longer
  serialize on a single mutex. LRU semantics for general and metadata
  caches are unchanged (tests cover the strict-LRU contract there).
* mocks: MockJWKCache, EnhancedMockJWKCache, mockJWKCacheForLogout,
  staticJWKCache satisfy the extended interface.
2026-04-30 10:14:10 +01:00
lukaszraczylo 2d1b04c637 review fixes apr 2026 (#130)
* Multiple fixes

- refresh coordinator dedup + memory pressure wire
- middleware sse consolidation + timer leak + claim cache
- universal cache sync backfill + isDebug gate
- lazy background task race
- memory monitor stw cached + refresh() api

* fix(auth): suppress OIDC redirects on non-navigation requests

- [x] Add isNonNavigationRequest using Sec-Fetch-Mode and Accept headers
- [x] Add comprehensive TestIsNonNavigationRequest
- [x] Update ServeHTTP to 401 non-navigation and AJAX requests

Fixes #129

* feat(config): add custom CA and insecure skip verify for OIDC TLS

- [x] Add CACertPath, CACertPEM, InsecureSkipVerify to Config
- [x] Implement loadCACertPool for CA bundle loading
- [x] Update HTTPClientConfig with RootCAs and InsecureSkipVerify
- [x] Apply CA pool and skip verify to pooled HTTP clients
- [x] Enhance configKey to distinguish TLS configs
- [x] Add comprehensive ca_cert_test.go

Fixes #125

* feat(oidc): add custom CA certificate support for private OIDC providers

- [x] Add caCertPath, caCertPEM, insecureSkipVerify config options
- [x] Update traefik.yml with new OIDC client config fields
- [x] Add configuration schema descriptions for new options
- [x] Update README table and add Custom CA Certificates section

* Fix the documentation.

* test(redis): add oversized argument rejection test

- [x] Add TestRedisConn_RejectOversizedArgumentBytes
- [x] Import strings package

* Dependencies cleanup
2026-04-19 10:12:00 +01:00
lukaszraczylo ccbb98b9dd fix-issue-122 (#128) 2026-03-04 00:23:30 +00:00
Serhii Vasyliev 1362cc0dac Improve debug logging around callback URL matching (#126)
* Add debug logging around callback URL matching in ServeHTTP

The callback URL comparison at the core of OIDC flow had zero logging,
making it extremely difficult to diagnose redirect loop issues caused
by misconfigured callbackURL (e.g., full URL vs path-only).

Every other path comparison in ServeHTTP already logs debug info
(logout, backchannel, frontchannel, excluded URLs), but the callback
URL check was completely silent.

Added debug logs that show:
- The values being compared (request path vs configured callback)
- Whether the match succeeded or failed
- Configured redirURLPath during initialization

This would have immediately revealed the root cause of issue #1
where callbackURL was set as a full URL but compared against
req.URL.Path which only contains the path component.

Closes #3

* improve-callback-url-logging: Add init-time logging for callbackURL config
2026-02-23 10:36:37 +00:00
Yuval Bar-On 249dcad1b3 fix: prevent deadlock in SessionData.Clear method (#114)
Move mutex unlock before calling Save() to prevent potential deadlock
when Save() method needs to acquire the same mutex.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: Claude <noreply@anthropic.com>
2026-02-16 15:02:33 +00:00
lukaszraczylo de4b4d7258 fix(cache): remove sync.Pool for Yaegi compatibility (#121)
- [x] Remove sync.Pool implementation that causes reflection panics
- [x] Replace pool-based NewRESPWriter with direct instantiation
- [x] Replace pool-based NewRESPReader with direct instantiation
- [x] Convert Release() methods to no-ops for API compatibility
- [x] Add documentation explaining sync.Pool removal for Yaegi
- [x] Remove "sync" import

Resolves #120
2026-01-19 17:52:31 +00:00
lukaszraczylo 9d52f1b018 feat(core): refactor linters config and improve code quality (#119)
- [x] Reorganize golangci-lint configuration with documented disable reasons
- [x] Simplify errcheck and revive linter rules with targeted exclusions
- [x] Pre-compile regex patterns in input_validation.go for performance
- [x] Fix type assertions in memory_shard.go and resp.go with safety checks
- [x] Replace string comparison with EqualFold for case-insensitive matching
- [x] Fix loop variable captures in jwk.go and logout.go
- [x] Change high goroutine log level from Info to Debug in autocleanup.go
- [x] Replace deprecated "cancelled" spelling with "canceled" throughout
- [x] Add nolint annotations for intentional unused parameters
- [x] Improve comment formatting for deprecated functions
- [x] Fix comment spelling: "marshalling" → "marshaling"
- [x] Refactor provider warnings formatting in internal/providers/warnings.go
- [x] Simplify metrics summary building in internal/recovery/metrics.go
- [x] Pre-allocate slice in error_recovery.go GetDegradedServices
- [x] Refactor context cancellation checks in redis.go
2026-01-15 10:40:49 +00:00
lukaszraczylo 57724918fe fix 116 (#118)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases

* docs: add security fix documentation for integer overflow protection

* test: fix goroutine tests to use mock OIDC servers

The TestContextAwareGoroutineManagement tests were making real HTTP
calls to hardcoded URLs like https://example.com, causing failures
in CI when those requests timeout or return HTTP errors.

Changes:
- Added createMockOIDCServer() helper function using httptest
- Updated GoroutineCleanupOnContextCancel to use mock server
- Updated NoGoroutineLeakOnMultipleInstances to use 3 mock servers
- Updated SingletonTasksAcrossInstances to use mock servers array

This prevents network calls and makes tests more reliable and faster.

Fixes test failures in GitHub Actions CI.
2026-01-08 22:50:46 +00:00
lukaszraczylo 775de2ada1 Fix cache serialisation (#117)
* Fix cache serialisation

* fix(cache): add integer overflow protection for serialization

- [x] Add maxCacheEntrySize constant (64 MiB) to prevent memory overflow
- [x] Validate byte slice size before adding marker byte
- [x] Validate JSON-serialized data size before marker addition
- [x] Add comprehensive overflow protection test cases
2026-01-08 22:06:19 +00:00
lukaszraczylo 7816e05c98 fix issue with logout url (#112)
* fix(logout): handle logout requests before OIDC initialization

- [x] Add debug logging to logout handler entry point
- [x] Move logout path check before OIDC initialization to enable logout when provider unavailable
- [x] Move excluded URL and SSE checks before initialization wait
- [x] Add debug logging for initialization wait to diagnose hanging requests
- [x] Add test for logout functionality without OIDC provider availability

* feat(logout): implement OIDC backchannel and front-channel logout

- [x] Add logout token validation and backchannel logout handler
- [x] Add front-channel logout handler with iframe support
- [x] Implement session invalidation cache for distributed deployments
- [x] Add comprehensive logout token claim verification (issuer, audience, events, iat, sid/sub)
- [x] Integrate session invalidation checks into authorization flow
- [x] Add configuration options for enabling backchannel/front-channel logout
- [x] Add extensive test coverage for logout flows and edge cases
- [x] Update documentation with logout configuration examples
- [x] Add middleware routing for logout endpoints
- [x] Extend cache manager with session invalidation cache support

Resolves #110

* fixup! feat(logout): implement OIDC backchannel and front-channel logout

* fixup! Merge branch 'main' into fix-issue-with-logout-url
2026-01-04 01:59:50 +00:00
Dominik Chilla 8bf7998150 Fix for Hashicorp Vault - accept opaque access tokens with dot-characters (#113) 2026-01-02 16:42:22 +00:00
muffn_ 22c4323fcb fix: set X-Forwarded-User header for SSE requests from existing session (#111)
Co-authored-by: muffin <MonsterMuffin@users.noreply.github.com>
2026-01-02 02:50:11 +00:00
lukaszraczylo 06b219d1f8 feat(dcr): Add Redis storage support for multi-replica deployments (#109)
- [x] Add file and Redis storage backends for DCR credentials
- [x] Implement storage abstraction with FileStore and RedisStore
- [x] Add factory function for automatic backend selection (auto/file/redis)
- [x] Integrate DCR credentials cache into UniversalCacheManager
- [x] Add comprehensive tests for storage backends and factory
- [x] Update configuration schema with storage backend options
- [x] Update documentation with multi-replica deployment guidance
- [x] Add Redis key prefix configuration for credential isolation
2025-12-31 12:52:39 +00:00
180 changed files with 17276 additions and 9362 deletions
+15
View File
@@ -0,0 +1,15 @@
# These are supported funding model platforms
github: lukaszraczylo
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
thanks_dev: # Replace with a single thanks.dev username
custom: https://monzo.me/lukaszraczylo
+1 -1
View File
@@ -18,6 +18,6 @@ jobs:
pr-checks:
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
with:
go-version: "1.24.11"
go-version: "1.25.x"
coverage-threshold: 70
secrets: inherit
+1 -1
View File
@@ -19,5 +19,5 @@ jobs:
release:
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
with:
go-version: "1.24.11"
go-version: "1.25.x"
secrets: inherit
+1
View File
@@ -1,3 +1,4 @@
docker/
.claude/*.out
*.test
.leann/
+49 -32
View File
@@ -14,21 +14,22 @@ linters:
- gosec
- misspell
- noctx
- nolintlint
- prealloc
- revive
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- whitespace
disable:
- exhaustive
- funlen
- gocognit
- gocyclo # Disabled: OAuth/OIDC flows are inherently complex
- goprintffuncname # Disabled: naming convention is project-specific
- lll
- mnd
- testpackage
- whitespace # Disabled: style preference about newlines
- wsl
settings:
dupl:
@@ -47,29 +48,13 @@ linters:
- fmt.Fprintln
goconst:
min-len: 3
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
min-occurrences: 15 # Increased to reduce noise for standard OAuth2/OIDC strings and common patterns like "true"
ignore-tests: true
gocritic:
# Using default enabled checks in v2
enabled-checks:
- appendCombine
- boolExprSimplify
- builtinShadow
- commentedOutCode
- emptyFallthrough
- equalFold
- hexLiteral
- indexAlloc
- initClause
- methodExprCall
- nestingReduce
- rangeExprCopy
- rangeValCopy
- stringXbytes
- typeAssertChain
- typeUnparen
- unlabelStmt
- yodaStyleExpr
# Disable style-only checks that add noise
disabled-checks:
- ifElseChain # Style preference, switch not always clearer
- elseif # Style preference
gocyclo:
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
gosec:
@@ -106,23 +91,23 @@ linters:
- name: error-return
- name: error-strings
- name: error-naming
- name: exported
- name: if-return
# - name: exported # Disabled: too noisy, not all exported functions need comments
# - name: if-return # Disabled: style preference
- name: increment-decrement
- name: var-naming
- name: var-declaration
- name: package-comments
# - name: var-naming # Disabled: too strict for legacy code (IP vs Ip)
# - name: var-declaration # Disabled: explicit zero values can be clearer
# - name: package-comments # Disabled: handled by other tools
- name: range
- name: receiver-naming
- name: time-naming
- name: unexported-return
- name: indent-error-flow
# - name: indent-error-flow # Disabled: style preference
- name: errorf
- name: empty-block
# - name: empty-block # Disabled: sometimes empty blocks are intentional
- name: superfluous-else
- name: unused-parameter
# - name: unused-parameter # Disabled: test callbacks and interface implementations often have required unused params
- name: unreachable-code
- name: redefines-builtin-id
# - name: redefines-builtin-id # Disabled: min/max helpers are common before Go 1.21
unparam:
check-exported: false
staticcheck:
@@ -132,8 +117,15 @@ linters:
- -QF1003 # Tagged switch - style preference, may affect Yaegi
- -QF1007 # Merge conditional assignment - style preference
- -QF1008 # Remove embedded field - may break Yaegi compatibility
- -QF1011 # Omit type from declaration - style preference
- -QF1012 # Use fmt.Fprintf - style preference
- -SA9003 # Empty branch - sometimes intentional for future work
- -ST1000 # Package comment format - not required for all packages
- -ST1003 # Package name format - allowed for test packages
- -ST1016 # Receiver name consistency - legacy code
- -ST1020 # Comment format for methods - style preference
- -ST1021 # Comment format for types - style preference
- -ST1023 # Omit type from declaration - style preference
exclusions:
generated: lax
rules:
@@ -144,18 +136,43 @@ linters:
- goconst
- gocyclo
- gosec
- govet
- ineffassign
- noctx
- prealloc
- unparam
- revive
- gocritic
path: _test\.go
- linters:
- dupl
- gocyclo
- govet
- noctx
- prealloc
- unparam
- revive
- gocritic
path: test.*\.go
- linters:
- gocritic
- unused
- errcheck
- revive
path: mocks.*\.go
- linters:
- errcheck
- revive
- gocritic
- govet
- unparam
path: internal/testutil/
- linters:
- govet
- unparam
- noctx
- prealloc
path: integration/
- linters:
- gosec
text: 'G404:'
+81 -1609
View File
File diff suppressed because it is too large Load Diff
+61
View File
@@ -0,0 +1,61 @@
# traefikoidc — Makefile
# Run `make help` for available targets.
GO ?= go
GOPATH := $(shell $(GO) env GOPATH)
# Pin to the yaegi version bundled by the deployed Traefik so yaegi-validate
# tests the real interpreter, not a newer one that may support more. Traefik
# v3.7.1 vendors yaegi v0.16.1 (Go ~1.22 stdlib surface). Bump when Traefik is.
YAEGI_VERSION ?= v0.16.1
TEST_TIMEOUT ?= 480s
.DEFAULT_GOAL := help
.PHONY: help
help: ## Show this help
@grep -hE '^[a-zA-Z0-9_-]+:.*## ' $(MAKEFILE_LIST) | awk 'BEGIN{FS=":.*## "}{printf " \033[36m%-16s\033[0m %s\n", $$1, $$2}'
.PHONY: build
build: ## Compile all packages (native toolchain)
$(GO) build ./...
.PHONY: fmt
fmt: ## Format sources with gofmt
gofmt -w $$(git ls-files '*.go' | grep -v '^vendor/')
.PHONY: vet
vet: ## Run go vet
$(GO) vet ./...
.PHONY: lint
lint: ## Run golangci-lint if available
@command -v golangci-lint >/dev/null 2>&1 && golangci-lint run ./... || echo "golangci-lint not installed; skipping"
.PHONY: staticcheck
staticcheck: ## Run staticcheck (matches the CI "Static Analysis" job; catches U1000 unused, etc.)
@command -v staticcheck >/dev/null 2>&1 || { echo ">> installing staticcheck"; $(GO) install honnef.co/go/tools/cmd/staticcheck@latest; }
@GOFLAGS=-buildvcs=false $$(command -v staticcheck || echo "$(GOPATH)/bin/staticcheck") ./...
.PHONY: test
test: ## Run the test suite
$(GO) test ./... -count=1 -timeout $(TEST_TIMEOUT)
.PHONY: vendor
vendor: ## Refresh and vendor dependencies
$(GO) mod tidy && $(GO) mod vendor
# yaegi-validate interprets the plugin under the yaegi interpreter the same way
# Traefik loads it. Native `go build`/`go test` use the standard compiler and do
# NOT catch yaegi-only incompatibilities (unsupported stdlib symbols, reflection
# edge cases). This target does. Importing the package forces yaegi to interpret
# every file in it plus its vendored deps; CreateConfig + New exercise the
# instantiation path. Pin YAEGI_VERSION to match Traefik's bundled yaegi if you
# need exact parity.
.PHONY: yaegi-validate
yaegi-validate: ## Verify the plugin loads under Traefik's yaegi interpreter
@command -v yaegi >/dev/null 2>&1 || { echo ">> installing yaegi@$(YAEGI_VERSION)"; $(GO) install github.com/traefik/yaegi/cmd/yaegi@$(YAEGI_VERSION); }
@echo ">> interpreting plugin under yaegi (as Traefik does)"
@DO_NOT_TRACK=1 GOFLAGS=-mod=vendor $$(command -v yaegi || echo "$(GOPATH)/bin/yaegi") run ./cmd/yaegicheck/main.go
.PHONY: check
check: vet staticcheck test yaegi-validate ## vet + staticcheck + tests + yaegi load validation
+353 -1933
View File
File diff suppressed because it is too large Load Diff
+49
View File
@@ -0,0 +1,49 @@
# Security Fix: Integer Overflow Protection in Cache Serialization
## Summary
Fixed **High severity** integer overflow vulnerability identified by GitHub Advanced Security in PR #117.
## Vulnerability
**Locations**: `universal_cache.go` lines 789 and 811
- `result := make([]byte, len(bytes)+1)` - Raw bytes path
- `result := make([]byte, len(jsonData)+1)` - JSON encoding path
**Risk**: Potential integer overflow when allocating memory for very large cache entries.
## Fix Applied
1. **Added size limit constant**:
```go
maxCacheEntrySize = 64 * 1024 * 1024 // 64 MiB
```
2. **Size validation before allocation**:
- Validates entry size doesn't exceed limit
- Validates adding marker byte won't overflow
- Returns descriptive error messages
3. **Comprehensive test coverage**:
- Oversized byte slices (>64 MiB)
- Exact max size edge case
- Safe sizes (normal operation)
- Large JSON data structures
## Verification
✅ All tests pass with race detection
✅ No security issues (golangci-lint, gosec)
✅ 76.3% test coverage maintained
## Impact
- No breaking changes
- Negligible performance overhead
- Prevents potential buffer overflows
- Predictable memory usage
---
**Date**: January 8, 2026
**Severity**: High → Resolved
+5 -3
View File
@@ -484,7 +484,8 @@ func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
session.SetAccessToken(opaqueAccessToken)
session.SetIDToken(idToken)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
if !authenticated || needsRefresh || expired {
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
authenticated, needsRefresh, expired)
@@ -623,7 +624,8 @@ func TestAuth0Scenario2StrictMode(t *testing.T) {
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
// In strict mode, this should FAIL (no fallback to ID token)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
if authenticated {
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
}
@@ -1491,7 +1493,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("Failed to set authenticated: %v", err)
}
session.SetEmail("user@company.com")
session.SetUserIdentifier("user@company.com")
session.SetIDToken(validJWT)
session.SetAccessToken(validJWT)
+69 -34
View File
@@ -4,8 +4,7 @@ import (
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
"time"
)
// validateRedirectCount checks if redirect limit is exceeded and handles the error
@@ -44,7 +43,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
// Clear all existing session data
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
session.SetEmail("")
session.SetUserIdentifier("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetIDToken("")
@@ -77,7 +76,12 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
csrfToken := uuid.NewString()
csrfToken, err := newUUIDv4()
if err != nil {
t.logger.Errorf("Failed to generate CSRF token: %v", err)
http.Error(rw, "Failed to generate CSRF token", http.StatusInternalServerError)
return
}
nonce, err := generateNonce()
if err != nil {
t.logger.Errorf("Failed to generate nonce: %v", err)
@@ -178,6 +182,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
}
codeVerifier := session.GetCodeVerifier()
if t.enablePKCE && codeVerifier == "" {
t.logger.Error("PKCE is enabled but code verifier is missing from session during callback")
t.sendErrorResponse(rw, req, "Authentication failed: PKCE verifier missing", http.StatusBadRequest)
return
}
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
@@ -246,7 +255,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
return
}
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
session.SetUserIdentifier(userIdentifier)
session.SetIDToken(tokenResponse.IDToken)
session.SetAccessToken(tokenResponse.AccessToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
@@ -259,7 +268,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
// Neutralize open-redirect payloads (e.g. //evil.com, /\evil.com) stored
// from the original request target before using it as the post-login
// redirect target. normalizeLogoutPath forces a host-relative path.
redirectPath = normalizeLogoutPath(incomingPath)
}
session.SetIncomingPath("")
@@ -286,7 +298,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.SetIDToken("")
session.SetAccessToken("")
session.SetRefreshToken("")
session.SetEmail("")
session.SetUserIdentifier("")
// Clear CSRF tokens to prevent replay attacks
session.SetCSRF("")
session.SetNonce("")
@@ -301,28 +313,6 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// isUserAuthenticated determines the authentication status and refresh requirements.
// It delegates to provider-specific validation methods that handle different token types
// and expiration behaviors.
// Parameters:
// - session: The session data containing authentication tokens.
//
// Returns:
// - authenticated (bool): True if the user has valid tokens.
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
// - expired (bool): True if the session is unauthenticated, the token is missing,
// or the token verification failed for reasons other than nearing/actual expiration.
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if t.isAzureProvider() {
return t.validateAzureTokens(session)
} else if t.isGoogleProvider() {
return t.validateGoogleTokens(session)
}
// Auth0 and other providers can now use standard validation
// which handles opaque tokens generically
return t.validateStandardTokens(session)
}
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
xhr := req.Header.Get("X-Requested-With")
@@ -334,9 +324,54 @@ func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
strings.Contains(accept, "application/json")
}
// isRefreshTokenExpired checks if refresh token is likely expired (older than 6 hours)
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
// This is a heuristic check - actual implementation would depend on
// the specific provider and token metadata
return false // Placeholder implementation
// isNonNavigationRequest reports whether the request is a browser
// sub-resource (script, image, stylesheet, fetch, serviceWorker) rather than
// a top-level HTML navigation. Non-navigation requests MUST NOT trigger an
// OIDC redirect flow: several sub-resource loads happening in parallel would
// each call defaultInitiateAuthentication, each overwriting the session's
// CSRF/nonce, breaking the eventual callback (issue #129).
//
// Detection prefers Sec-Fetch-Mode, which all modern browsers send
// (Chrome/Edge/Firefox/Safari). For older or non-browser clients we fall
// back to Accept: if Accept is present and does not list text/html, treat
// it as a sub-resource. An empty/missing Accept is assumed to be navigation
// (safer to redirect than 401 on an ambiguous request).
func (t *TraefikOidc) isNonNavigationRequest(req *http.Request) bool {
if mode := req.Header.Get("Sec-Fetch-Mode"); mode != "" {
return mode != "navigate"
}
accept := req.Header.Get("Accept")
if accept == "" || accept == "*/*" {
return false
}
return !strings.Contains(accept, "text/html")
}
// isRefreshTokenExpired checks whether the stored refresh token is likely
// past its useful lifetime, using the cookie-side issued_at timestamp set by
// SetRefreshToken. IdPs do not expose RT TTL on the wire, so this is a
// conservative heuristic gated by t.maxRefreshTokenAge (default 6h, set via
// MaxRefreshTokenAgeSeconds; 0 disables the check).
//
// The point of this check is to short-circuit the refresh path BEFORE the
// thundering herd hits the IdP for a token the provider has almost certainly
// revoked. Together with the RefreshCoordinator wireup, it keeps Grafana-
// style polling clients from looping on invalid_grant after a long pause.
func (t *TraefikOidc) isRefreshTokenExpired(session *SessionData) bool {
if t == nil || session == nil {
return false
}
if t.maxRefreshTokenAge <= 0 {
return false
}
issuedAt := session.GetRefreshTokenIssuedAt()
if issuedAt.IsZero() {
// No timestamp recorded (legacy session pre-dating the issued_at
// field). Don't force a re-auth - attempt refresh once and let the
// IdP be the source of truth.
return false
}
return time.Since(issuedAt) > t.maxRefreshTokenAge
}
+88 -4
View File
@@ -192,7 +192,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Pre-populate session with old data
_ = session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-access-token-with-many-characters")
session.SetRefreshToken("old-refresh-token-with-many-characters")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
@@ -207,7 +207,7 @@ func (s *AuthFlowBehaviourSuite) TestPrepareSessionForAuthentication() {
// Verify old data is cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
// Verify new data is set
s.Equal(csrfToken, session.GetCSRF())
@@ -305,6 +305,90 @@ func (s *AuthFlowBehaviourSuite) TestIsAjaxRequest() {
}
}
// TestIsNonNavigationRequest verifies browser sub-resource detection used to
// suppress OIDC redirects on parallel static-asset loads (issue #129).
func (s *AuthFlowBehaviourSuite) TestIsNonNavigationRequest() {
testCases := []struct {
headers map[string]string
name string
expectNonNavigation bool
}{
{
name: "Sec-Fetch-Mode navigate",
headers: map[string]string{"Sec-Fetch-Mode": "navigate"},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode no-cors",
headers: map[string]string{"Sec-Fetch-Mode": "no-cors"},
expectNonNavigation: true,
},
{
name: "Sec-Fetch-Mode cors",
headers: map[string]string{"Sec-Fetch-Mode": "cors"},
expectNonNavigation: true,
},
{
name: "Sec-Fetch-Mode same-origin (fetch in page)",
headers: map[string]string{"Sec-Fetch-Mode": "same-origin"},
expectNonNavigation: true,
},
{
name: "Accept text/html (fallback)",
headers: map[string]string{"Accept": "text/html,application/xhtml+xml"},
expectNonNavigation: false,
},
{
name: "Accept image/png (fallback)",
headers: map[string]string{"Accept": "image/png,image/*;q=0.8"},
expectNonNavigation: true,
},
{
name: "Accept application/javascript (fallback)",
headers: map[string]string{"Accept": "application/javascript"},
expectNonNavigation: true,
},
{
name: "Accept */* treated as navigation",
headers: map[string]string{"Accept": "*/*"},
expectNonNavigation: false,
},
{
name: "No Accept header assumed navigation",
headers: map[string]string{},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode beats Accept (navigate wins)",
headers: map[string]string{
"Sec-Fetch-Mode": "navigate",
"Accept": "application/javascript",
},
expectNonNavigation: false,
},
{
name: "Sec-Fetch-Mode beats Accept (no-cors wins)",
headers: map[string]string{
"Sec-Fetch-Mode": "no-cors",
"Accept": "text/html",
},
expectNonNavigation: true,
},
}
for _, tc := range testCases {
s.Run(tc.name, func() {
req := httptest.NewRequest(http.MethodGet, "/_static/asset.js", nil)
for key, value := range tc.headers {
req.Header.Set(key, value)
}
result := s.tOidc.isNonNavigationRequest(req)
s.Equal(tc.expectNonNavigation, result)
})
}
}
// TestHandleCallback_MissingState tests callback with missing state parameter
func (s *AuthFlowBehaviourSuite) TestHandleCallback_MissingState() {
sessionManager, err := NewSessionManager(
@@ -627,7 +711,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
session, err := sessionManager.GetSession(req)
s.Require().NoError(err)
_ = session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
session.SetIDToken("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.signature")
session.mainSession.Values["redirect_count"] = 3
@@ -636,7 +720,7 @@ func (s *AuthFlowBehaviourSuite) TestHandleExpiredToken() {
// Session should be cleared
s.False(session.GetAuthenticated())
s.Empty(session.GetEmail())
s.Empty(session.GetUserIdentifier())
s.Empty(session.GetIDToken())
// Redirect count should be reset to 0 and then incremented by defaultInitiateAuthentication
+4 -3
View File
@@ -599,8 +599,9 @@ func GetGlobalTaskMemoryMonitor(logger *Logger) *TaskMemoryMonitor {
return globalTaskMemoryMonitor
}
// NewTaskMemoryMonitor creates a new memory monitor for task registry
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior
// NewTaskMemoryMonitor creates a new memory monitor for task registry.
//
// Deprecated: Use GetGlobalTaskMemoryMonitor instead for singleton behavior.
func NewTaskMemoryMonitor(logger *Logger, registry *TaskRegistry) *TaskMemoryMonitor {
return GetGlobalTaskMemoryMonitor(logger)
}
@@ -712,7 +713,7 @@ func (mm *TaskMemoryMonitor) checkForMemoryIssues(stats TaskMemoryStats) {
// Check for goroutine leaks (arbitrary threshold)
if stats.Goroutines > 100 {
mm.logger.Infof("High goroutine count detected: %d", stats.Goroutines)
mm.logger.Debugf("High goroutine count detected: %d", stats.Goroutines)
}
// Check for heap growth without corresponding GC activity
+8 -4
View File
@@ -262,7 +262,8 @@ func TestAzureOIDCRegression(t *testing.T) {
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
// Test that CSRF is preserved during Azure validation failures
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := tOidc.validateAzureTokensRS(rs)
// Should not be authenticated due to validation failure
if authenticated {
@@ -453,7 +454,8 @@ func TestValidateGoogleTokens(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.validateGoogleTokens(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.validateGoogleTokensRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
@@ -637,7 +639,8 @@ func TestIsUserAuthenticated(t *testing.T) {
defer func() { ts.tOidc.issuerURL = originalIssuer }()
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.isUserAuthenticated(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.isUserAuthenticatedRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
@@ -762,7 +765,8 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
session := tt.setupSession()
auth, refresh, expired := ts.tOidc.validateAzureTokens(session)
rs := (&requestState{}).captureSession(session)
auth, refresh, expired := ts.tOidc.validateAzureTokensRS(rs)
if auth != tt.expectedAuth {
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
+23 -14
View File
@@ -29,8 +29,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
pressure := monitor.GetMemoryPressure()
assert.Equal(t, MemoryPressureNone, pressure)
// Collect stats to populate lastStats
monitor.GetCurrentStats()
// Explicitly sample to populate lastStats; GetCurrentStats is now a
// cached read and no longer forces a runtime.ReadMemStats.
monitor.Refresh()
// Now should return a valid pressure level
pressure = monitor.GetMemoryPressure()
@@ -46,11 +47,13 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Start monitoring should not panic
// Start monitoring should not panic. Interval is clamped to the
// minimum (30s); we rely on Refresh() when we need a synchronous
// sample instead of waiting for a tick.
assert.NotPanics(t, func() {
ctx := context.Background()
monitor.StartMonitoring(ctx, 100*time.Millisecond)
time.Sleep(GetTestDuration(50 * time.Millisecond))
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
})
// Clean up
@@ -117,6 +120,9 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
// Refresh forces a synchronous sample; GetCurrentStats is a cached
// read, so we sample first to guarantee fresh data.
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
@@ -450,12 +456,12 @@ func TestMemoryMonitorIntegration(t *testing.T) {
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
defer monitor.StopMonitoring()
// Start monitoring
// Start monitoring. The interval is clamped to the minimum (30s) so
// the ticker won't fire during the test; drive the sample manually via
// Refresh() instead.
ctx := context.Background()
monitor.StartMonitoring(ctx, 50*time.Millisecond)
// Wait for at least one check
time.Sleep(GetTestDuration(150 * time.Millisecond))
monitor.StartMonitoring(ctx, 0)
monitor.Refresh()
// Get pressure (should be a valid pressure level)
pressure := monitor.GetMemoryPressure()
@@ -488,6 +494,7 @@ func TestMemoryStatsCollection(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
assert.NotNil(t, stats)
@@ -501,6 +508,7 @@ func TestMemoryStatsCollection(t *testing.T) {
thresholds := DefaultMemoryAlertThresholds()
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
monitor.Refresh()
stats := monitor.GetCurrentStats()
// Should calculate and include pressure level
@@ -521,13 +529,14 @@ func TestMemoryStatsCollection(t *testing.T) {
// Allocate some memory
_ = make([]byte, 1024*1024) // 1MB
// Get stats before GC
beforeStats := monitor.GetCurrentStats()
// Get stats before GC (explicit Refresh so we have a fresh pre-GC
// snapshot to compare against, not the constructor baseline).
beforeStats := monitor.Refresh()
// Trigger GC
// Trigger GC (internally Refresh()es before and after)
monitor.TriggerGC()
// Get stats after GC
// Get stats after GC from cache (TriggerGC already refreshed it)
afterStats := monitor.GetCurrentStats()
// After GC should have different stats
+683
View File
@@ -0,0 +1,683 @@
// Package traefikoidc — bearer-token (M2M) authentication path.
//
// Disabled by default. When enabled via Config.EnableBearerAuth, requests
// presenting "Authorization: Bearer <jwt>" are validated against the
// configured OIDC provider (signature, issuer, audience, exp, replay-Get)
// and the request is forwarded downstream without creating a cookie session.
//
// Design rules (kept here in code as the single source of truth):
// - Access tokens only. ID tokens are rejected via detectTokenType.
// - Audience is mandatory (enforced at startup in main.go).
// - alg + kid pinned BEFORE JWKS fetch to deny amplification probes.
// - iat upper-age cap bounds clock-skew / forever-token abuse.
// - Multi-audience tokens require matching azp.
// - Per-IP 401 throttle returns 429 + Retry-After after a threshold.
// - JTI Set is suppressed (skipReplayMarking) but JTI Get stays — revoked
// tokens (RevokeToken adds to blacklist) are still rejected.
// - Identifier is read from BearerIdentifierClaim (default "sub"), never
// from UserIdentifierClaim, to avoid the unverified-email spoofing path.
// - Identifier is sanitized: length cap, control chars, bidi-override,
// delimiter chars (, ; =) rejected.
// - On excluded URLs the Authorization header is stripped before forwarding.
//
// See docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md and
// docs/BEARER_AUTH.md for the full threat model.
package traefikoidc
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"unicode"
)
const bearerPrefix = "Bearer "
// bearerAlgAllowlist is the set of JWS algorithms accepted on the bearer
// path. Asymmetric-only — HS* would allow public-key-as-HMAC-secret attacks
// if any operator ever rotates a key into the symmetric branch by mistake;
// "none" is obvious. Matches the allowlist enforced inside jwt.Verify but is
// checked here BEFORE the JWKS fetch so attacker noise can't amplify.
var bearerAlgAllowlist = map[string]struct{}{
"RS256": {}, "RS384": {}, "RS512": {},
"PS256": {}, "PS384": {}, "PS512": {},
"ES256": {}, "ES384": {}, "ES512": {},
}
// bearerKidMaxLen caps the JOSE kid header length to keep memory and cache-key
// usage bounded against attacker-controlled values.
const bearerKidMaxLen = 256
// validKidChar is the allowlist for kid header characters. Letters, digits,
// dot, underscore, hyphen, equals. Intentionally narrow; real-world kid
// values are short URL-safe-base64-ish identifiers.
func validKidChar(r rune) bool {
if r >= 'a' && r <= 'z' {
return true
}
if r >= 'A' && r <= 'Z' {
return true
}
if r >= '0' && r <= '9' {
return true
}
switch r {
case '.', '_', '-', '=':
return true
}
return false
}
// bearerError categorizes failure modes for the response builder. Categories
// map 1:1 to the table in docs/superpowers/specs/2026-05-18-bearer-token-auth-design.md
// §9 so behavior is auditable from spec to code.
type bearerErrorKind int
const (
bearerErrInvalidRequest bearerErrorKind = iota
bearerErrInvalidToken
bearerErrTokenInactive
bearerErrInvalidIdentifier
bearerErrForbidden
bearerErrThrottled
bearerErrIntrospectionUnavailable
)
type bearerError struct {
kind bearerErrorKind
reason string
}
func (e *bearerError) Error() string { return e.reason }
func newBearerError(kind bearerErrorKind, reason string) *bearerError {
return &bearerError{kind: kind, reason: reason}
}
// joseHeader is the minimal subset of the JWS protected header we inspect
// BEFORE running the full verification pipeline. Lifted out so the alg+kid
// pin can run without paying for parseJWT's full claim decode.
type joseHeader struct {
Alg string `json:"alg"`
Kid string `json:"kid"`
Typ string `json:"typ"`
}
// parseBearerJOSEHeader decodes the first JWT segment for early alg/kid pinning.
// Does not touch the payload or signature — those are the verifier's job.
// Returns nil on success; *bearerError on rejection so the handler can map
// directly to a status code. The decoded header itself is not surfaced because
// callers don't need it (verifyTokenWithOpts re-parses internally).
func parseBearerJOSEHeader(token string) *bearerError {
dot := strings.IndexByte(token, '.')
if dot <= 0 {
return newBearerError(bearerErrInvalidToken, "malformed JWT: no header segment")
}
raw, err := base64.RawURLEncoding.DecodeString(token[:dot])
if err != nil {
// Some IdPs pad with '='; tolerate by retrying with StdEncoding.
raw, err = base64.URLEncoding.DecodeString(token[:dot])
if err != nil {
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not base64url")
}
}
var hdr joseHeader
if err := json.Unmarshal(raw, &hdr); err != nil {
return newBearerError(bearerErrInvalidToken, "malformed JWT: header not JSON")
}
if _, ok := bearerAlgAllowlist[hdr.Alg]; !ok {
return newBearerError(bearerErrInvalidToken, fmt.Sprintf("disallowed alg %q on bearer path", hdr.Alg))
}
if hdr.Kid == "" {
return newBearerError(bearerErrInvalidToken, "missing kid header")
}
if len(hdr.Kid) > bearerKidMaxLen {
return newBearerError(bearerErrInvalidToken, "kid header exceeds max length")
}
for _, r := range hdr.Kid {
if !validKidChar(r) {
return newBearerError(bearerErrInvalidToken, "kid header contains disallowed characters")
}
}
return nil
}
// headerClaimRuneReason reports why a rune is unsafe to inject into a request
// header value, or "" if the rune is acceptable. Shared core of the bearer-path
// identifier sanitizer and the cookie-path header claim sanitizer: rejects
// control chars (CRLF/header injection), Unicode bidi-override runes (RTL
// spoofing of admin UI / SIEM), and the delimiters , ; = (a comma in a group
// name would inject extra entries into a comma-joined header).
func headerClaimRuneReason(r rune) string {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
// The , ; = delimiters are only unsafe for values placed into delimited or
// list contexts (a comma-joined header, or an identifier downstreams may
// split). They are valid in arbitrary single header values, so this stricter
// check is used for the cookie-path identifier and the group/role list, NOT
// for free-form templated header output (see headerValueReason).
if r == ',' || r == ';' || r == '=' {
return "delimiter character"
}
return ""
}
// headerInjectionRuneReason reports why a rune is unsafe in ANY HTTP header
// value, or "" if acceptable. Rejects control characters (CR/LF header
// injection) and Unicode bidi-override runes (RTL spoofing of admin UIs/SIEMs).
// Unlike headerClaimRuneReason it does NOT reject , ; = which are legitimate in
// free-form header values (e.g. an opaque "Authorization: Bearer <token>").
func headerInjectionRuneReason(r rune) string {
if unicode.IsControl(r) {
return "control character"
}
if (r >= 0x202A && r <= 0x202E) || (r >= 0x2066 && r <= 0x2069) {
return "bidi-override character"
}
return ""
}
// headerValueReason reports why value is unsafe to forward as a free-form HTTP
// header value, or "" if acceptable. It rejects values over maxLen (maxLen<=0
// disables the check) and values containing control or bidi-override runes, but
// permits , ; = (valid in header values). Empty is allowed. The reason string
// never includes the value, so it is safe to log.
func headerValueReason(value string, maxLen int) string {
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerInjectionRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// headerClaimValueReason reports why value is unsafe to inject into a
// downstream request header, or "" if it is acceptable. It rejects empty
// values, values exceeding maxLen (maxLen<=0 disables the length check), and
// values containing any rune rejected by headerClaimRuneReason. The reason
// string is safe to log (it never includes the value itself).
func headerClaimValueReason(value string, maxLen int) string {
if value == "" {
return "empty value"
}
if maxLen > 0 && len(value) > maxLen {
return "exceeds max length"
}
for _, r := range value {
if reason := headerClaimRuneReason(r); reason != "" {
return reason
}
}
return ""
}
// sanitizeHeaderClaimValue validates a claim-derived value before it is
// injected into a downstream request header. It trims surrounding whitespace
// and fails closed (ok=false) on empty values, values exceeding maxLen
// (maxLen<=0 disables the length check), or values containing any rune rejected
// by headerClaimRuneReason. Used by the cookie/session path, which — unlike the
// bearer path — does not otherwise sanitize the principal identifier or the
// group/role strings joined into X-User-Groups / X-User-Roles.
func sanitizeHeaderClaimValue(raw string, maxLen int) (string, bool) {
value := strings.TrimSpace(raw)
if headerClaimValueReason(value, maxLen) != "" {
return "", false
}
return value, true
}
// sanitizeBearerIdentifier validates and trims a principal identifier before
// it is injected into request headers. Layered defense: net/http will reject
// CRLF on the wire too, but rejecting early gives clearer error logs and
// prevents bidi-override / delimiter chars that pass net/http's narrower
// checks but confuse downstream parsers and admin UIs.
func sanitizeBearerIdentifier(raw string, maxLen int) (string, *bearerError) {
identifier := strings.TrimSpace(raw)
if identifier == "" {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier claim empty")
}
if maxLen > 0 && len(identifier) > maxLen {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier exceeds max length")
}
for _, r := range identifier {
if reason := headerClaimRuneReason(r); reason != "" {
return "", newBearerError(bearerErrInvalidIdentifier, "identifier contains "+reason)
}
}
return identifier, nil
}
// resolveBearerIdentifier picks the principal identifier from claims using
// the configured BearerIdentifierClaim (default "sub"). Decoupled from
// userIdentifierClaim (cookie path) to avoid the unverified-email spoofing
// vector documented in the spec §13.
func resolveBearerIdentifier(claims map[string]interface{}, claimName string) (string, *bearerError) {
if claimName == "" {
claimName = "sub"
}
raw, ok := claims[claimName]
if !ok {
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("missing claim %q", claimName))
}
str, ok := raw.(string)
if !ok {
return "", newBearerError(bearerErrInvalidIdentifier, fmt.Sprintf("claim %q not a string", claimName))
}
return str, nil
}
// enforceMultiAudienceAzp implements the spec hardening: when aud is a
// multi-element array, require an azp claim equal to clientID. Single-string
// aud is unaffected (existing verifyAudience handles it).
func enforceMultiAudienceAzp(claims map[string]interface{}, clientID string) *bearerError {
audRaw, ok := claims["aud"]
if !ok {
return nil // verifyToken already rejects missing aud
}
arr, ok := audRaw.([]interface{})
if !ok {
return nil // single-string aud
}
if len(arr) <= 1 {
return nil
}
azpRaw, ok := claims["azp"]
if !ok {
return newBearerError(bearerErrInvalidToken, "multi-audience token missing azp")
}
azp, ok := azpRaw.(string)
if !ok || azp == "" {
return newBearerError(bearerErrInvalidToken, "multi-audience token has empty/non-string azp")
}
if azp != clientID {
return newBearerError(bearerErrInvalidToken, "multi-audience token azp does not match clientID")
}
return nil
}
// enforceIatAge implements the spec MaxTokenAgeSeconds bound on iat. Bounds
// clock-manipulation / forever-token abuse without rejecting tokens with a
// normal iat just because the issuer's clock skews a few seconds.
func enforceIatAge(claims map[string]interface{}, maxAge time.Duration) *bearerError {
if maxAge <= 0 {
return nil
}
iatRaw, ok := claims["iat"].(float64)
if !ok {
// jwt.Verify already requires iat; this branch shouldn't be reached.
return newBearerError(bearerErrInvalidToken, "missing iat claim")
}
iat := time.Unix(int64(iatRaw), 0)
if time.Since(iat) > maxAge {
return newBearerError(bearerErrInvalidToken, "token iat outside age bound")
}
return nil
}
// hashIdentifierForLog returns a short SHA-256 prefix safe for info-level
// logs. Full identifier is only emitted at debug. Satisfies the audit
// requirement (trace which principal was rejected) without leaking PII.
func hashIdentifierForLog(identifier string) string {
if identifier == "" {
return "(none)"
}
sum := sha256.Sum256([]byte(identifier))
return hex.EncodeToString(sum[:4]) // 8 hex chars
}
// --- Per-IP failure throttle ---
// bearerFailureTracker records consecutive bearer-auth 401s per source IP and
// parks repeat offenders in a 429 penalty box. Limits offline-guessing-style
// attacks and protects the shared rate-limiter / JWKS endpoint from being
// burned by a single source.
type bearerFailureTracker struct {
mu sync.Mutex
entries map[string]*bearerFailureEntry
// Configuration snapshot. Captured at construction so a hot reconfigure
// doesn't race with the per-request paths.
threshold int
window time.Duration
penalty time.Duration
}
type bearerFailureEntry struct {
firstFailureAt time.Time
penaltyUntil time.Time
count int
}
func newBearerFailureTracker(threshold int, window, penalty time.Duration) *bearerFailureTracker {
if threshold <= 0 {
threshold = 20
}
if window <= 0 {
window = 60 * time.Second
}
if penalty <= 0 {
penalty = 60 * time.Second
}
return &bearerFailureTracker{
entries: make(map[string]*bearerFailureEntry),
threshold: threshold,
window: window,
penalty: penalty,
}
}
// blocked reports whether the source IP is currently in the penalty box.
// Returns (true, retryAfter) when blocked; (false, 0) when allowed.
func (b *bearerFailureTracker) blocked(ip string) (bool, time.Duration) {
if b == nil || ip == "" {
return false, 0
}
b.mu.Lock()
defer b.mu.Unlock()
e, ok := b.entries[ip]
if !ok {
return false, 0
}
now := time.Now()
if !e.penaltyUntil.IsZero() && now.Before(e.penaltyUntil) {
return true, time.Until(e.penaltyUntil)
}
return false, 0
}
// recordFailure increments the failure counter for the given IP and trips
// the penalty box once threshold-within-window is exceeded.
func (b *bearerFailureTracker) recordFailure(ip string) {
if b == nil || ip == "" {
return
}
b.mu.Lock()
defer b.mu.Unlock()
now := time.Now()
e, ok := b.entries[ip]
if !ok || now.Sub(e.firstFailureAt) > b.window {
e = &bearerFailureEntry{firstFailureAt: now}
b.entries[ip] = e
}
e.count++
if e.count >= b.threshold {
e.penaltyUntil = now.Add(b.penalty)
}
}
// recordSuccess clears the failure counter for the given IP after a
// successful bearer auth.
func (b *bearerFailureTracker) recordSuccess(ip string) {
if b == nil || ip == "" {
return
}
b.mu.Lock()
defer b.mu.Unlock()
e, ok := b.entries[ip]
if !ok {
return
}
// Preserve an active penalty so a single success cannot wipe an in-effect
// lockout; only reset the counter when no penalty is active or it has expired.
now := time.Now()
if e.penaltyUntil.IsZero() || now.After(e.penaltyUntil) {
e.count = 0
e.firstFailureAt = now
}
}
// clientIPForBearer returns the source IP used to key the failure tracker.
// Trusts only the request's transport-level RemoteAddr; X-Forwarded-For is
// intentionally ignored to avoid attacker-controlled key spoofing. Behind a
// trusted reverse proxy where every request shares one IP, the throttle is
// still useful (caps attacker churn through that proxy) — operators wanting
// per-real-client throttling must terminate at this middleware.
func clientIPForBearer(req *http.Request) string {
if req == nil {
return ""
}
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
return req.RemoteAddr
}
return host
}
// --- Bearer auth entrypoint ---
// detectBearerToken returns (token, true) when the request carries a usable
// Authorization: Bearer header. Case-insensitive on the scheme. Returns
// ("", false) for any other shape.
func detectBearerToken(req *http.Request) (string, bool) {
if req == nil {
return "", false
}
h := req.Header.Get("Authorization")
if len(h) < len(bearerPrefix) {
return "", false
}
if !strings.EqualFold(h[:len(bearerPrefix)], bearerPrefix) {
return "", false
}
token := strings.TrimSpace(h[len(bearerPrefix):])
if token == "" {
return "", false
}
return token, true
}
// hasSessionCookie reports whether the request carries any cookie matching
// the session prefix. Used to implement the cookie-wins-by-default
// precedence rule when both bearer and cookie are present.
func (t *TraefikOidc) hasSessionCookie(req *http.Request) bool {
if t.sessionManager == nil {
return false
}
prefix := t.sessionManager.GetCookiePrefix()
if prefix == "" {
return false
}
for _, c := range req.Cookies() {
if strings.HasPrefix(c.Name, prefix) {
return true
}
}
return false
}
// writeBearerError writes the canonical 401/403/429/503 response per spec §9.
// Body is always generic; reason is logged at debug only. The
// WWW-Authenticate hint is gated by config (default on, RFC 6750 compliant).
func (t *TraefikOidc) writeBearerError(rw http.ResponseWriter, req *http.Request, err *bearerError) {
var (
status int
errCode string
body string
retryAfter time.Duration
)
switch err.kind {
case bearerErrInvalidRequest:
status = http.StatusUnauthorized
errCode = "invalid_request"
body = "Unauthorized"
case bearerErrInvalidToken, bearerErrTokenInactive, bearerErrInvalidIdentifier:
status = http.StatusUnauthorized
errCode = "invalid_token"
body = "Unauthorized"
case bearerErrForbidden:
status = http.StatusForbidden
body = "Access denied"
case bearerErrThrottled:
status = http.StatusTooManyRequests
body = "Too Many Requests"
retryAfter = t.bearerFailurePenalty
case bearerErrIntrospectionUnavailable:
status = http.StatusServiceUnavailable
body = "Service Unavailable"
default:
status = http.StatusUnauthorized
body = "Unauthorized"
}
if t.bearerEmitWWWAuthenticate && errCode != "" {
rw.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error=%q`, errCode))
}
if retryAfter > 0 {
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
}
rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
rw.WriteHeader(status)
_, _ = rw.Write([]byte(body)) // Safe to ignore: best-effort error body write
if t.logger != nil {
t.logger.Debugf("bearer auth rejected: status=%d category=%v reason=%q path=%s",
status, err.kind, err.reason, req.URL.Path)
}
}
// handleBearerRequest is the entry point invoked by ServeHTTP when the
// EnableBearerAuth flag is set, the request carries an Authorization: Bearer
// header, and the (configurable) cookie-precedence rule allows the bearer
// path to run.
func (t *TraefikOidc) handleBearerRequest(rw http.ResponseWriter, req *http.Request) {
ip := clientIPForBearer(req)
if blocked, retryAfter := t.bearerFailureTracker.blocked(ip); blocked {
throttled := newBearerError(bearerErrThrottled, "ip in penalty box")
// Preserve the actual retry-after even if it diverged from the
// configured default (clock-skew, partial-window expiry).
if retryAfter > 0 {
rw.Header().Set("Retry-After", fmt.Sprintf("%d", int(retryAfter.Seconds())))
}
t.writeBearerError(rw, req, throttled)
return
}
token, ok := detectBearerToken(req)
if !ok {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidRequest, "missing or empty bearer token"))
return
}
if len(token) > AccessTokenConfig.MaxLength {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token exceeds max length"))
return
}
if strings.Count(token, ".") != 2 {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, newBearerError(bearerErrInvalidToken, "token is not a 3-segment JWT"))
return
}
if bErr := parseBearerJOSEHeader(token); bErr != nil {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, bErr)
return
}
p, bErr := t.buildPrincipalFromBearerToken(token)
if bErr != nil {
t.bearerFailureTracker.recordFailure(ip)
t.writeBearerError(rw, req, bErr)
return
}
t.bearerFailureTracker.recordSuccess(ip)
if t.logger != nil {
t.logger.Debugf("bearer auth success: identifier_hash=%s path=%s",
hashIdentifierForLog(p.Identifier), req.URL.Path)
}
t.forwardAuthorized(rw, req, p)
}
// buildPrincipalFromBearerToken runs the full bearer verification pipeline
// described in spec §7.3 and returns a principal ready for forwardAuthorized.
// Returns a typed *bearerError on failure so the caller can map to status.
func (t *TraefikOidc) buildPrincipalFromBearerToken(token string) (*principal, *bearerError) {
if err := t.verifyTokenWithOpts(token, verifyOpts{skipReplayMarking: true}); err != nil {
return nil, newBearerError(bearerErrInvalidToken, "token verification failed: "+err.Error())
}
parsed, err := parseJWT(token)
if err != nil {
return nil, newBearerError(bearerErrInvalidToken, "post-verify parseJWT failed: "+err.Error())
}
claims := parsed.Claims
// Token-type guard. Reuse the well-tested classifier which already
// checks nonce / typ=at+jwt / token_use / scope / aud-vs-clientID.
if t.detectTokenType(parsed, token) {
return nil, newBearerError(bearerErrInvalidToken, "ID tokens are not accepted on the bearer path")
}
// Belt-and-braces explicit rejection (cheap, catches edge cases not
// covered by detectTokenType's heuristic).
if nonce, ok := claims["nonce"].(string); ok && nonce != "" {
return nil, newBearerError(bearerErrInvalidToken, "nonce claim present (ID-token shape)")
}
if tu, ok := claims["token_use"].(string); ok && tu == "id" {
return nil, newBearerError(bearerErrInvalidToken, "token_use=id rejected")
}
if bErr := enforceMultiAudienceAzp(claims, t.clientID); bErr != nil {
return nil, bErr
}
if bErr := enforceIatAge(claims, t.maxTokenAge); bErr != nil {
return nil, bErr
}
if t.requireTokenIntrospection {
if bErr := t.introspectOnBearerPath(token); bErr != nil {
return nil, bErr
}
}
rawIdentifier, bErr := resolveBearerIdentifier(claims, t.bearerIdentifierClaim)
if bErr != nil {
return nil, bErr
}
identifier, bErr := sanitizeBearerIdentifier(rawIdentifier, t.maxIdentifierLength)
if bErr != nil {
return nil, bErr
}
subject, _ := claims["sub"].(string)
clientID, _ := claims["azp"].(string)
if clientID == "" {
clientID, _ = claims["client_id"].(string)
}
return &principal{
Source: sourceBearer,
Identifier: identifier,
Subject: subject,
ClientID: clientID,
Claims: claims,
AccessToken: token,
}, nil
}
// introspectOnBearerPath calls the existing RFC 7662 introspector when the
// operator demands real-time revocation. Distinguishes "token revoked" (401)
// from "endpoint unavailable" (503) so transient infra failures don't look
// like credential failures.
func (t *TraefikOidc) introspectOnBearerPath(token string) *bearerError {
resp, err := t.introspectToken(token)
if err != nil {
return newBearerError(bearerErrIntrospectionUnavailable, "introspection failed: "+err.Error())
}
if !resp.Active {
return newBearerError(bearerErrTokenInactive, "introspection reports token inactive")
}
return nil
}
+830
View File
@@ -0,0 +1,830 @@
package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"golang.org/x/time/rate"
)
// =============================================================================
// Helper builders
// =============================================================================
// makeBearerJWT constructs a JWT with explicit header + claims for tests.
// Signature is opaque (b64("signature")) — bearer tests don't exercise the
// real cryptographic verifier; verification is bypassed via tokenCache pre-
// seed so the bearer pipeline under test sees a "verified" token.
func makeBearerJWT(t *testing.T, header, claims map[string]interface{}) string {
t.Helper()
hb, err := json.Marshal(header)
if err != nil {
t.Fatalf("marshal header: %v", err)
}
cb, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
return fmt.Sprintf("%s.%s.%s",
base64.RawURLEncoding.EncodeToString(hb),
base64.RawURLEncoding.EncodeToString(cb),
base64.RawURLEncoding.EncodeToString([]byte("signature")),
)
}
// defaultBearerHeader produces the standard RS256+kid header used in tests.
func defaultBearerHeader() map[string]interface{} {
return map[string]interface{}{"alg": "RS256", "kid": "test-kid"}
}
// defaultBearerClaims produces a baseline access-token claim set. Tests
// shallow-clone and override fields as needed.
func defaultBearerClaims() map[string]interface{} {
return map[string]interface{}{
"iss": "https://issuer.example.com",
"aud": "https://api.example.com",
"sub": "service-account-1",
"scope": "api:read api:write",
"exp": float64(time.Now().Add(time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
}
}
// makeBearerOIDC constructs a TraefikOidc wired for bearer auth tests. The
// real verifyTokenWithOpts pipeline is short-circuited via tokenCache pre-
// seed: any token Set into t.tokenCache returns nil from VerifyToken,
// letting tests exercise the post-verify bearer logic (classifier, identifier,
// throttle, header forwarding) without standing up JWKs.
func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
t.Helper()
sm := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("error"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://issuer.example.com",
audience: "https://api.example.com",
clientID: "https://api.example.com",
tokenCache: NewTokenCache(),
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
allowedRolesAndGroups: map[string]struct{}{},
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
ctx: context.Background(),
enableBearerAuth: true,
stripAuthorizationHeader: true,
bearerEmitWWWAuthenticate: true,
bearerOverridesCookie: false,
bearerIdentifierClaim: "sub",
maxIdentifierLength: 256,
maxTokenAge: 24 * time.Hour,
bearerFailureThreshold: 20,
bearerFailureWindow: 60 * time.Second,
bearerFailurePenalty: 60 * time.Second,
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
}
oidc.extractClaimsFunc = extractClaims
close(oidc.initComplete)
return oidc
}
// seedVerified pre-populates the tokenCache so verifyTokenWithOpts short-
// circuits to nil for the given token. Mirrors the production fast-return
// path at token_manager.go for previously-verified tokens.
func seedVerified(t *testing.T, oidc *TraefikOidc, token string, claims map[string]interface{}) {
t.Helper()
if oidc.tokenCache == nil {
oidc.tokenCache = NewTokenCache()
}
oidc.tokenCache.Set(token, claims, time.Hour)
}
// =============================================================================
// Unit tests — small helpers
// =============================================================================
func TestDetectBearerToken(t *testing.T) {
t.Parallel()
cases := []struct {
name string
header string
want string
ok bool
}{
{"missing header", "", "", false},
{"basic auth", "Basic abc", "", false},
{"bearer with token", "Bearer abc.def.ghi", "abc.def.ghi", true},
{"lowercase bearer", "bearer abc.def.ghi", "abc.def.ghi", true},
{"mixed case", "BeArEr abc.def.ghi", "abc.def.ghi", true},
{"empty token after prefix", "Bearer ", "", false},
{"bearer no space", "Bearerabc", "", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
if tc.header != "" {
req.Header.Set("Authorization", tc.header)
}
got, ok := detectBearerToken(req)
if ok != tc.ok || got != tc.want {
t.Fatalf("got=(%q, %v), want=(%q, %v)", got, ok, tc.want, tc.ok)
}
})
}
}
func TestParseBearerJOSEHeader(t *testing.T) {
t.Parallel()
mk := func(t *testing.T, h map[string]interface{}) string {
return makeBearerJWT(t, h, map[string]interface{}{"sub": "x"})
}
cases := []struct {
header map[string]interface{}
name string
wantErr bool
}{
{name: "valid RS256", header: map[string]interface{}{"alg": "RS256", "kid": "k1"}, wantErr: false},
{name: "valid ES512", header: map[string]interface{}{"alg": "ES512", "kid": "abc-_.="}, wantErr: false},
{name: "alg=none rejected", header: map[string]interface{}{"alg": "none", "kid": "k1"}, wantErr: true},
{name: "alg=HS256 rejected", header: map[string]interface{}{"alg": "HS256", "kid": "k1"}, wantErr: true},
{name: "missing kid", header: map[string]interface{}{"alg": "RS256"}, wantErr: true},
{name: "kid too long", header: map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}, wantErr: true},
{name: "kid bad chars", header: map[string]interface{}{"alg": "RS256", "kid": "evil/../etc/passwd"}, wantErr: true},
{name: "kid with space", header: map[string]interface{}{"alg": "RS256", "kid": "key one"}, wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
token := mk(t, tc.header)
err := parseBearerJOSEHeader(token)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestSanitiseBearerIdentifier(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in string
want string
wantErr bool
}{
{"normal sub", "service-account-1", "service-account-1", false},
{"email-like", "alice@example.com", "alice@example.com", false},
{"trim whitespace", " abc ", "abc", false},
{"empty", "", "", true},
{"only whitespace", " ", "", true},
{"control char (newline)", "alice\nbob", "", true},
{"control char (CR)", "alice\rbob", "", true},
{"control char (NUL)", "alice\x00bob", "", true},
{"bidi override", "alice\u202ebob", "", true},
{"bidi isolate", "alice\u2066bob", "", true},
{"comma delimiter", "alice,bob", "", true},
{"semicolon delimiter", "alice;bob", "", true},
{"equals delimiter", "alice=bob", "", true},
{"over length", strings.Repeat("a", 257), "", true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := sanitizeBearerIdentifier(tc.in, 256)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
if !tc.wantErr && got != tc.want {
t.Fatalf("got=%q want=%q", got, tc.want)
}
})
}
}
func TestResolveBearerIdentifier(t *testing.T) {
t.Parallel()
cases := []struct {
claims map[string]interface{}
name string
claim string
want string
wantErr bool
}{
{name: "default sub", claims: map[string]interface{}{"sub": "abc"}, claim: "", want: "abc"},
{name: "explicit sub", claims: map[string]interface{}{"sub": "abc"}, claim: "sub", want: "abc"},
{name: "custom client_id claim", claims: map[string]interface{}{"client_id": "svc"}, claim: "client_id", want: "svc"},
{name: "missing claim", claims: map[string]interface{}{"other": "x"}, claim: "sub", wantErr: true},
{name: "non-string claim", claims: map[string]interface{}{"sub": 123}, claim: "sub", wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := resolveBearerIdentifier(tc.claims, tc.claim)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
if !tc.wantErr && got != tc.want {
t.Fatalf("got=%q want=%q", got, tc.want)
}
})
}
}
func TestEnforceMultiAudienceAzp(t *testing.T) {
t.Parallel()
const cid = "https://api.example.com"
cases := []struct {
claims map[string]interface{}
name string
wantErr bool
}{
{name: "single string aud", claims: map[string]interface{}{"aud": "x"}, wantErr: false},
{name: "single element array", claims: map[string]interface{}{"aud": []interface{}{"x"}}, wantErr: false},
{name: "multi-aud with matching azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": cid}, wantErr: false},
{name: "multi-aud missing azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}}, wantErr: true},
{name: "multi-aud empty azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": ""}, wantErr: true},
{name: "multi-aud wrong azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": "other"}, wantErr: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := enforceMultiAudienceAzp(tc.claims, cid)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestEnforceIatAge(t *testing.T) {
t.Parallel()
now := time.Now()
cases := []struct {
name string
iat float64
maxAge time.Duration
wantErr bool
}{
{name: "fresh", iat: float64(now.Unix()), maxAge: time.Hour, wantErr: false},
{name: "23h59m old, max 24h", iat: float64(now.Add(-23*time.Hour - 59*time.Minute).Unix()), maxAge: 24 * time.Hour, wantErr: false},
{name: "25h old, max 24h", iat: float64(now.Add(-25 * time.Hour).Unix()), maxAge: 24 * time.Hour, wantErr: true},
{name: "1970 token", iat: float64(0), maxAge: 24 * time.Hour, wantErr: true},
{name: "maxAge disabled (0)", iat: float64(0), maxAge: 0, wantErr: false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := enforceIatAge(map[string]interface{}{"iat": tc.iat}, tc.maxAge)
if (err != nil) != tc.wantErr {
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
}
})
}
}
func TestBearerFailureTracker(t *testing.T) {
t.Parallel()
tr := newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
const ip = "10.0.0.1"
// Below threshold: not blocked.
for i := 0; i < 2; i++ {
tr.recordFailure(ip)
if b, _ := tr.blocked(ip); b {
t.Fatalf("blocked too early after %d failures", i+1)
}
}
// Threshold reached: blocked.
tr.recordFailure(ip)
if b, retry := tr.blocked(ip); !b || retry <= 0 {
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
}
// A success while a penalty is active must NOT wipe the in-effect lockout
// (otherwise a single success could clear an attacker's penalty).
tr.recordSuccess(ip)
if b, _ := tr.blocked(ip); !b {
t.Fatalf("expected still blocked after success while penalty active")
}
// Other IPs are unaffected.
if b, _ := tr.blocked("10.0.0.2"); b {
t.Fatalf("unrelated IP should not be blocked")
}
// With an expired penalty, a success resets the counter so a subsequent
// sub-threshold failure does not immediately re-block.
tr2 := newBearerFailureTracker(3, 60*time.Second, 1*time.Millisecond)
const ip2 = "10.0.0.3"
for i := 0; i < 3; i++ {
tr2.recordFailure(ip2)
}
time.Sleep(5 * time.Millisecond) // let the short penalty expire
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked after penalty expiry")
}
tr2.recordSuccess(ip2) // resets count since penalty has passed
tr2.recordFailure(ip2) // single failure, well below threshold
if b, _ := tr2.blocked(ip2); b {
t.Fatalf("expected unblocked: counter should have reset after success")
}
}
// =============================================================================
// Integration tests — full ServeHTTP via the bearer pipeline
// =============================================================================
func TestServeHTTP_Bearer_HappyPath(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
var capturedHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
capturedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled.Load() {
t.Fatalf("expected next handler to run; got status=%d body=%q", rw.Code, rw.Body.String())
}
if rw.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", rw.Code)
}
if got := capturedHeaders.Get("X-Forwarded-User"); got != "service-account-1" {
t.Fatalf("X-Forwarded-User=%q, want service-account-1", got)
}
if got := capturedHeaders.Get("Authorization"); got != "" {
t.Fatalf("Authorization should be stripped, got=%q", got)
}
}
func TestServeHTTP_Bearer_StripAuthDisabled(t *testing.T) {
t.Parallel()
var capturedAuth string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.stripAuthorizationHeader = false
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !strings.HasPrefix(capturedAuth, "Bearer ") {
t.Fatalf("expected Authorization to be forwarded, got=%q", capturedAuth)
}
}
func TestServeHTTP_Bearer_RejectIDToken(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for ID token rejection")
})
oidc := makeBearerOIDC(t, next)
// ID-token shape: nonce claim present and no scope. detectTokenType
// returns true.
claims := map[string]interface{}{
"iss": "https://issuer.example.com",
"aud": "https://api.example.com",
"sub": "user-1",
"nonce": "n-0S6_WzA2Mj",
"exp": float64(time.Now().Add(time.Hour).Unix()),
"iat": float64(time.Now().Unix()),
}
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
if wa := rw.Header().Get("WWW-Authenticate"); !strings.Contains(wa, `error="invalid_token"`) {
t.Fatalf("expected WWW-Authenticate invalid_token, got=%q", wa)
}
}
func TestServeHTTP_Bearer_AlgNoneRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for alg=none")
})
oidc := makeBearerOIDC(t, next)
header := map[string]interface{}{"alg": "none", "kid": "k1"}
claims := defaultBearerClaims()
token := makeBearerJWT(t, header, claims)
// Even if we pre-seeded the cache, the early alg pin runs FIRST.
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_KidTooLongRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for oversized kid")
})
oidc := makeBearerOIDC(t, next)
header := map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}
claims := defaultBearerClaims()
token := makeBearerJWT(t, header, claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MultiAudRequiresAzp(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for multi-aud without azp")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
delete(claims, "azp")
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MultiAudWithAzpAccepted(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
claims["azp"] = oidc.clientID
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK || !nextCalled.Load() {
t.Fatalf("expected 200 + next called; got status=%d called=%v", rw.Code, nextCalled.Load())
}
}
func TestServeHTTP_Bearer_IatTooOldRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for old iat")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["iat"] = float64(time.Now().Add(-25 * time.Hour).Unix())
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_IdentifierWithBidiRejected(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for bidi identifier")
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["sub"] = "alice\u202ebob"
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_ReplayRegression(t *testing.T) {
t.Parallel()
var successCount atomic.Int32
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
successCount.Add(1)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
claims["jti"] = "regression-jti"
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
for i := 0; i < 100; i++ {
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK {
t.Fatalf("iteration %d: status=%d, want 200", i, rw.Code)
}
}
if successCount.Load() != 100 {
t.Fatalf("successCount=%d, want 100", successCount.Load())
}
}
func TestServeHTTP_Bearer_ThrottleTrips429(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run during throttle test")
})
oidc := makeBearerOIDC(t, next)
oidc.bearerFailureTracker = newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
// Send malformed bearers from the same RemoteAddr until threshold trips.
send := func() *httptest.ResponseRecorder {
req := httptest.NewRequest("GET", "/api/work", nil)
req.RemoteAddr = "10.0.0.5:1234"
req.Header.Set("Authorization", "Bearer not-a-jwt")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
return rw
}
for i := 0; i < 3; i++ {
rw := send()
if rw.Code != http.StatusUnauthorized {
t.Fatalf("pre-throttle iteration %d: status=%d, want 401", i, rw.Code)
}
}
// 4th request: throttled.
rw := send()
if rw.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 after threshold, got %d", rw.Code)
}
if ra := rw.Header().Get("Retry-After"); ra == "" {
t.Fatalf("expected Retry-After header on 429")
}
}
func TestServeHTTP_Bearer_ExcludedURLStripsAuth(t *testing.T) {
t.Parallel()
var capturedAuth string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.excludedURLs = map[string]struct{}{"/favicon.ico": {}}
req := httptest.NewRequest("GET", "/favicon.ico", nil)
req.Header.Set("Authorization", "Bearer abc.def.ghi")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusOK {
t.Fatalf("excluded path should pass; got %d", rw.Code)
}
if capturedAuth != "" {
t.Fatalf("Authorization must be stripped on excluded paths, got=%q", capturedAuth)
}
}
func TestServeHTTP_Bearer_RolesGate(t *testing.T) {
t.Parallel()
cases := []struct {
name string
rolesClaim []interface{}
want int
}{
{name: "matching role", rolesClaim: []interface{}{"admin"}, want: http.StatusOK},
{name: "no matching role", rolesClaim: []interface{}{"viewer"}, want: http.StatusForbidden},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.allowedRolesAndGroups = map[string]struct{}{"admin": {}}
oidc.roleClaimName = "roles"
claims := defaultBearerClaims()
claims["roles"] = tc.rolesClaim
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != tc.want {
t.Fatalf("status=%d, want %d", rw.Code, tc.want)
}
})
}
}
func TestServeHTTP_Bearer_CookieWinsByDefault(t *testing.T) {
t.Parallel()
// Both cookie and bearer present: cookie path runs (which will redirect
// to /authorize since the cookie is empty/unauthenticated).
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
prefix := oidc.sessionManager.GetCookiePrefix()
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Cookie path consumed the request; bearer was ignored. Since the
// cookie is empty, the cookie path will either 302 to /authorize or
// return 401 — in either case, next must NOT be called.
if nextCalled.Load() {
t.Fatalf("next must not be called when bearer is ignored due to cookie precedence")
}
}
func TestServeHTTP_Bearer_BearerOverridesCookie(t *testing.T) {
t.Parallel()
var nextCalled atomic.Bool
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled.Store(true)
w.WriteHeader(http.StatusOK)
})
oidc := makeBearerOIDC(t, next)
oidc.bearerOverridesCookie = true
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
prefix := oidc.sessionManager.GetCookiePrefix()
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled.Load() || rw.Code != http.StatusOK {
t.Fatalf("expected bearer to win with override; status=%d called=%v", rw.Code, nextCalled.Load())
}
}
func TestServeHTTP_Bearer_OversizedToken(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for oversized token")
})
oidc := makeBearerOIDC(t, next)
huge := strings.Repeat("a", AccessTokenConfig.MaxLength+1)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+huge)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_MalformedJWT(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("next must not run for malformed JWT")
})
oidc := makeBearerOIDC(t, next)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer not.jwt") // 1 dot
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", rw.Code)
}
}
func TestServeHTTP_Bearer_FeatureOffPassesThrough(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Should not be reached: cookie path runs and (with no session)
// will redirect or 401. We assert no panic / next not called.
t.Fatalf("next must not run when bearer is off and no valid session exists")
})
oidc := makeBearerOIDC(t, next)
oidc.enableBearerAuth = false
claims := defaultBearerClaims()
token := makeBearerJWT(t, defaultBearerHeader(), claims)
seedVerified(t, oidc, token, claims)
req := httptest.NewRequest("GET", "/api/work", nil)
req.Header.Set("Authorization", "Bearer "+token)
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Expect non-200: either 302 to /authorize or 401. The point is the
// bearer pipeline didn't run.
if rw.Code == http.StatusOK {
t.Fatalf("expected non-200 when bearer is off; got %d", rw.Code)
}
}
// =============================================================================
// Startup validation tests
// =============================================================================
func TestStartupValidation_BearerRequiresAudience(t *testing.T) {
t.Parallel()
cfg := CreateConfig()
cfg.ProviderURL = "https://issuer.example.com"
cfg.ClientID = "id"
cfg.ClientSecret = "secret"
cfg.CallbackURL = "/oauth/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
cfg.EnableBearerAuth = true
cfg.Audience = ""
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
if err == nil || !strings.Contains(err.Error(), "requires Audience") {
t.Fatalf("expected audience-required error, got %v", err)
}
}
func TestStartupValidation_BearerRejectsEmailIdentifier(t *testing.T) {
t.Parallel()
cfg := CreateConfig()
cfg.ProviderURL = "https://issuer.example.com"
cfg.ClientID = "id"
cfg.ClientSecret = "secret"
cfg.CallbackURL = "/oauth/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
cfg.EnableBearerAuth = true
cfg.Audience = "https://api.example.com"
cfg.BearerIdentifierClaim = "email"
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
if err == nil || !strings.Contains(err.Error(), "bearerIdentifierClaim=\"email\"") {
t.Fatalf("expected email-identifier rejection, got %v", err)
}
}
// =============================================================================
// Principal invariants
// =============================================================================
func TestBuildPrincipalFromSession_NoIdentifier(t *testing.T) {
t.Parallel()
oidc := &TraefikOidc{logger: NewLogger("error")}
if p := oidc.buildPrincipalFromSession(nil); p != nil {
t.Fatalf("nil session must produce nil principal")
}
}
+137
View File
@@ -0,0 +1,137 @@
package traefikoidc
import (
"encoding/pem"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
// testCertPEM returns a valid PEM-encoded certificate harvested from an
// httptest.NewTLSServer. Using httptest keeps the test free of any
// handwritten static cert that could expire.
func testCertPEM(t *testing.T) string {
t.Helper()
srv := httptest.NewTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
t.Cleanup(srv.Close)
cert := srv.Certificate()
if cert == nil {
t.Fatal("httptest.NewTLSServer did not expose a certificate")
}
return string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))
}
func TestLoadCACertPool_Empty(t *testing.T) {
cfg := &Config{}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool != nil {
t.Errorf("expected nil pool when no CA source configured, got %v", pool)
}
}
func TestLoadCACertPool_InlinePEM(t *testing.T) {
cfg := &Config{CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPEM")
}
}
func TestLoadCACertPool_InlinePEM_Garbage(t *testing.T) {
cfg := &Config{CACertPEM: "not a pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for garbage CACertPEM, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
if !strings.Contains(err.Error(), "caCertPEM") {
t.Errorf("error should name the failing field, got: %v", err)
}
}
func TestLoadCACertPool_FilePath(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool for valid CACertPath")
}
}
func TestLoadCACertPool_FilePath_Missing(t *testing.T) {
cfg := &Config{CACertPath: "/does/not/exist/ca.pem"}
pool, err := cfg.loadCACertPool()
if err == nil {
t.Fatal("expected error for missing CACertPath, got nil")
}
if pool != nil {
t.Errorf("expected nil pool on error, got %v", pool)
}
}
func TestLoadCACertPool_Combined(t *testing.T) {
// Both inline and file sources populated — certificates from both should
// be accepted into the same pool.
dir := t.TempDir()
path := filepath.Join(dir, "ca.pem")
if err := os.WriteFile(path, []byte(testCertPEM(t)), 0o600); err != nil {
t.Fatalf("writing temp PEM: %v", err)
}
cfg := &Config{CACertPath: path, CACertPEM: testCertPEM(t)}
pool, err := cfg.loadCACertPool()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pool == nil {
t.Fatal("expected non-nil pool when both sources set")
}
}
func TestSharedTransportPool_ConfigKeyDistinguishesCAAndSkipVerify(t *testing.T) {
p := GetGlobalTransportPool()
cfgSystem := DefaultHTTPClientConfig()
cfgSkip := DefaultHTTPClientConfig()
cfgSkip.InsecureSkipVerify = true
cfgCustomCA := DefaultHTTPClientConfig()
pool, err := (&Config{CACertPEM: testCertPEM(t)}).loadCACertPool()
if err != nil {
t.Fatalf("loadCACertPool: %v", err)
}
cfgCustomCA.RootCAs = pool
keys := map[string]string{
"system": p.configKey(cfgSystem),
"skip": p.configKey(cfgSkip),
"customCA": p.configKey(cfgCustomCA),
}
seen := make(map[string]string, len(keys))
for name, key := range keys {
if dup, ok := seen[key]; ok {
t.Errorf("configKey collision: %s and %s share key %q", name, dup, key)
}
seen[key] = name
}
}
+42 -4
View File
@@ -16,19 +16,23 @@ type CacheManager struct {
}
var (
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
globalCacheManagerInstance *CacheManager
cacheManagerInitOnce sync.Once
cacheManagerActiveFingerprint string
)
// GetGlobalCacheManager returns a singleton CacheManager instance
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
// GetGlobalCacheManager returns a singleton CacheManager instance.
//
// Deprecated: Use GetGlobalCacheManagerWithConfig instead.
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
return GetGlobalCacheManagerWithConfig(wg, nil)
}
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
fp := redisFingerprint(config)
cacheManagerInitOnce.Do(func() {
cacheManagerActiveFingerprint = fp
var redisConfig *RedisConfig
var logger *Logger
@@ -54,9 +58,27 @@ func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheM
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
}
})
// Warn loudly if a later instance asks for a DIFFERENT explicit Redis
// backend than the one that won initialization: the cache manager is a
// process-global singleton shared across plugin instances (yaegi), so this
// instance's divergent configuration is silently ignored, which would
// otherwise collapse cache/state isolation between routes (rank 9).
if fp != "" && cacheManagerActiveFingerprint != "" && fp != cacheManagerActiveFingerprint {
NewLogger(config.LogLevel).Errorf("cache manager already initialized with Redis backend %q; this instance's Redis backend %q is IGNORED (process-global singleton). Use a single consistent cache configuration across all routes.", cacheManagerActiveFingerprint, fp)
}
return globalCacheManagerInstance
}
// redisFingerprint returns a stable identifier for an explicitly-enabled Redis
// backend (address + key prefix), or "" when Redis is not explicitly enabled.
// Used to detect divergent cache configurations across plugin instances.
func redisFingerprint(config *Config) string {
if config == nil || config.Redis == nil || !config.Redis.Enabled {
return ""
}
return config.Redis.Address + "|" + config.Redis.KeyPrefix
}
// GetSharedTokenBlacklist returns the shared token blacklist cache
func (cm *CacheManager) GetSharedTokenBlacklist() CacheInterface {
cm.mu.RLock()
@@ -104,6 +126,22 @@ func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache(), managed: true}
}
// GetSharedSessionInvalidationCache returns the shared session invalidation cache
// for backchannel and front-channel logout (IdP-initiated logout)
func (cm *CacheManager) GetSharedSessionInvalidationCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetSessionInvalidationCache(), managed: true}
}
// GetSharedRefreshResultCache returns the short-lived refresh-result cache used
// by the refresh path to coalesce grants across Traefik replicas via Redis.
func (cm *CacheManager) GetSharedRefreshResultCache() CacheInterface {
cm.mu.RLock()
defer cm.mu.RUnlock()
return &CacheInterfaceWrapper{cache: cm.manager.GetRefreshResultCache(), managed: true}
}
// Close gracefully shuts down all cache components
func (cm *CacheManager) Close() error {
cm.mu.Lock()
+295
View File
@@ -0,0 +1,295 @@
package traefikoidc
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"os"
"time"
)
// isSupportedClientAssertionAlg reports whether alg is a recognized JWS
// algorithm for private_key_jwt (RFC 7523 §2.2).
func isSupportedClientAssertionAlg(alg string) bool {
switch alg {
case "RS256", "RS384", "RS512",
"PS256", "PS384", "PS512",
"ES256", "ES384", "ES512":
return true
}
return false
}
// ClientAssertionSigner builds and signs client_assertion JWTs (RFC 7523 §2.2).
type ClientAssertionSigner struct {
key crypto.PrivateKey
alg string
kid string
// rand is the entropy source for jti generation and PSS/ECDSA signing.
// Defaults to crypto/rand.Reader when nil.
rand io.Reader
// now returns the current time. Defaults to time.Now when nil.
now func() time.Time
}
// NewClientAssertionSigner parses pemBytes as a private key, validates that
// alg is consistent with the key type, and returns a ready-to-use signer.
// kid is placed verbatim in the JWS header.
//
// PEM block types understood:
// - "PRIVATE KEY" → PKCS#8 (tried first for all types)
// - "RSA PRIVATE KEY" → PKCS#1
// - "EC PRIVATE KEY" → SEC1
func NewClientAssertionSigner(pemBytes []byte, alg, kid string) (*ClientAssertionSigner, error) {
if !isSupportedClientAssertionAlg(alg) {
return nil, fmt.Errorf("unsupported client assertion alg %q", alg)
}
if kid == "" {
return nil, fmt.Errorf("kid must not be empty")
}
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, fmt.Errorf("no PEM block found in private key material")
}
var key crypto.PrivateKey
var parseErr error
switch block.Type {
case "PRIVATE KEY":
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
case "RSA PRIVATE KEY":
key, parseErr = x509.ParsePKCS1PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
key, parseErr = x509.ParseECPrivateKey(block.Bytes)
default:
// Best-effort fallback for unknown block types.
key, parseErr = x509.ParsePKCS8PrivateKey(block.Bytes)
}
if parseErr != nil {
return nil, fmt.Errorf("failed to parse private key (block type %q): %w", block.Type, parseErr)
}
if err := validateAlgKeyMatch(alg, key); err != nil {
return nil, err
}
return &ClientAssertionSigner{key: key, alg: alg, kid: kid}, nil
}
// validateAlgKeyMatch returns an error when alg implies a key type that does
// not match the actual key.
func validateAlgKeyMatch(alg string, key crypto.PrivateKey) error {
switch alg[0] {
case 'R', 'P': // RS* or PS*
if _, ok := key.(*rsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an RSA key, got %T", alg, key)
}
case 'E': // ES*
if _, ok := key.(*ecdsa.PrivateKey); !ok {
return fmt.Errorf("alg %q requires an EC key, got %T", alg, key)
}
}
return nil
}
// Sign constructs and returns a signed client_assertion JWT.
// audience is typically the token endpoint URL (RFC 7523 §3).
// clientID is used as both iss and sub per RFC 7523 §2.2.
func (s *ClientAssertionSigner) Sign(audience, clientID string) (string, error) {
rander := s.rand
if rander == nil {
rander = rand.Reader
}
nowFn := s.now
if nowFn == nil {
nowFn = time.Now
}
now := nowFn()
// 16 random bytes as lowercase hex for jti uniqueness.
jtiBytes := make([]byte, 16)
if _, err := io.ReadFull(rander, jtiBytes); err != nil {
return "", fmt.Errorf("failed to generate jti: %w", err)
}
jti := hex.EncodeToString(jtiBytes)
header := map[string]string{
"alg": s.alg,
"typ": "JWT",
"kid": s.kid,
}
hdrJSON, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT header: %w", err)
}
claims := map[string]any{
"iss": clientID,
"sub": clientID,
"aud": audience,
"jti": jti,
"iat": now.Unix(),
"exp": now.Add(60 * time.Second).Unix(),
}
claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("failed to marshal JWT claims: %w", err)
}
hdrB64 := base64.RawURLEncoding.EncodeToString(hdrJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signingInput := hdrB64 + "." + claimsB64
sig, err := s.sign(rander, []byte(signingInput))
if err != nil {
return "", err
}
return signingInput + "." + base64.RawURLEncoding.EncodeToString(sig), nil
}
// sign computes raw signature bytes for signingInput per s.alg.
// validateAlgKeyMatch in NewClientAssertionSigner guarantees the key type
// matches s.alg, but the comma-ok asserts here keep errcheck happy and
// surface internal misuse loudly instead of via panic.
func (s *ClientAssertionSigner) sign(rander io.Reader, input []byte) ([]byte, error) {
switch s.alg {
case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
rsaKey, ok := s.key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *rsa.PrivateKey, got %T", s.alg, s.key)
}
hash := rsaHashForAlg(s.alg)
digest := hashSum(hash, input)
if s.alg[0] == 'R' {
return signRSAPKCS1v15(rander, rsaKey, hash, digest)
}
return signRSAPSS(rander, rsaKey, hash, digest)
case "ES256", "ES384", "ES512":
ecKey, ok := s.key.(*ecdsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("internal: alg %q requires *ecdsa.PrivateKey, got %T", s.alg, s.key)
}
hash := ecHashForAlg(s.alg)
digest := hashSum(hash, input)
return signECDSA(rander, ecKey, digest)
}
return nil, fmt.Errorf("unhandled alg %q", s.alg)
}
func rsaHashForAlg(alg string) crypto.Hash {
switch alg {
case "RS256", "PS256":
return crypto.SHA256
case "RS384", "PS384":
return crypto.SHA384
case "RS512", "PS512":
return crypto.SHA512
}
return 0
}
func ecHashForAlg(alg string) crypto.Hash {
switch alg {
case "ES256":
return crypto.SHA256
case "ES384":
return crypto.SHA384
case "ES512":
return crypto.SHA512
}
return 0
}
func hashSum(h crypto.Hash, input []byte) []byte {
switch h {
case crypto.SHA256:
sum := sha256.Sum256(input)
return sum[:]
case crypto.SHA384:
sum := sha512.Sum384(input)
return sum[:]
case crypto.SHA512:
sum := sha512.Sum512(input)
return sum[:]
}
return nil
}
func signRSAPKCS1v15(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
sig, err := rsa.SignPKCS1v15(rander, key, hash, digest)
if err != nil {
return nil, fmt.Errorf("RSA PKCS1v15 signing failed: %w", err)
}
return sig, nil
}
func signRSAPSS(rander io.Reader, key *rsa.PrivateKey, hash crypto.Hash, digest []byte) ([]byte, error) {
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hash}
sig, err := rsa.SignPSS(rander, key, hash, digest, opts)
if err != nil {
return nil, fmt.Errorf("RSA PSS signing failed: %w", err)
}
return sig, nil
}
// signECDSA produces the JWS raw r||s signature (RFC 7515 App. A.3).
// Each scalar is zero-padded to (curve.BitSize+7)/8 bytes.
func signECDSA(rander io.Reader, key *ecdsa.PrivateKey, digest []byte) ([]byte, error) {
r, ss, err := ecdsa.Sign(rander, key, digest)
if err != nil {
return nil, fmt.Errorf("ECDSA signing failed: %w", err)
}
byteLen := (key.Curve.Params().BitSize + 7) / 8
sig := make([]byte, 2*byteLen)
padBigInt(sig[0:byteLen], r)
padBigInt(sig[byteLen:], ss)
return sig, nil
}
// padBigInt writes n as a fixed-width big-endian integer into buf.
func padBigInt(buf []byte, n *big.Int) {
b := n.Bytes()
copy(buf[len(buf)-len(b):], b)
}
// buildClientAssertionSignerFromConfig loads key material and constructs a
// ClientAssertionSigner. Called from NewWithContext when
// ClientAuthMethod == "private_key_jwt".
func buildClientAssertionSignerFromConfig(config *Config) (*ClientAssertionSigner, error) {
var pemBytes []byte
if config.ClientAssertionPrivateKey != "" {
pemBytes = []byte(config.ClientAssertionPrivateKey)
} else {
data, err := os.ReadFile(config.ClientAssertionKeyPath)
if err != nil {
return nil, fmt.Errorf("read clientAssertionKeyPath %q: %w", config.ClientAssertionKeyPath, err)
}
pemBytes = data
}
alg := config.ClientAssertionAlg
if alg == "" {
alg = "RS256"
}
return NewClientAssertionSigner(pemBytes, alg, config.ClientAssertionKeyID)
}
+46
View File
@@ -0,0 +1,46 @@
//go:build ignore
// Command yaegicheck verifies that the traefikoidc plugin can be imported and
// instantiated by the yaegi interpreter — the same way Traefik loads a plugin.
//
// It is run by `make yaegi-validate`. Importing the plugin package forces yaegi
// to interpret every source file in the package (and its vendored
// dependencies), so any construct yaegi cannot handle (unsupported stdlib
// symbol, reflection edge case, etc.) surfaces here rather than at Traefik load
// time. CreateConfig + New additionally exercise the instantiation path
// (session manager, cookie codec, caches, key derivation) under the interpreter.
package main
import (
"context"
"fmt"
"net/http"
"os"
oidc "github.com/lukaszraczylo/traefikoidc"
)
func main() {
cfg := oidc.CreateConfig()
cfg.ProviderURL = "https://accounts.google.com"
cfg.ClientID = "yaegi-check-client"
cfg.ClientSecret = "yaegi-check-secret"
cfg.CallbackURL = "/oauth2/callback"
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef"
cfg.RateLimit = 100
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h, err := oidc.New(context.Background(), next, cfg, "yaegi-check")
if err != nil {
fmt.Println("FAIL: New returned an error under yaegi:", err)
os.Exit(1)
}
if h == nil {
fmt.Println("FAIL: New returned a nil handler under yaegi")
os.Exit(1)
}
if closer, ok := h.(interface{ Close() error }); ok {
_ = closer.Close()
}
fmt.Println("OK: traefikoidc imported + CreateConfig + New succeeded under yaegi")
}
+2 -2
View File
@@ -7,7 +7,7 @@ import (
// REDACTED is the placeholder value for sensitive information
const REDACTED = "[REDACTED]"
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
// MarshalJSON implements custom JSON marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalJSON() ([]byte, error) {
// Build a map manually to avoid type alias issues with yaegi
@@ -47,7 +47,7 @@ func (c Config) MarshalJSON() ([]byte, error) {
return json.Marshal(result)
}
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
// MarshalYAML implements custom YAML marshaling to redact sensitive fields
// Rewritten without type aliases for yaegi compatibility
func (c Config) MarshalYAML() (interface{}, error) {
// Build a map manually to avoid type alias issues with yaegi
-76
View File
@@ -278,82 +278,6 @@ func TestHTTPClientProfiler_Methods_CoverageBoost(t *testing.T) {
}
}
// =============================================================================
// SECURITY MONITORING TESTS
// =============================================================================
func TestSecurityMonitor_StopCleanupRoutine_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
EnablePatternDetection: true,
EnableDetailedLogging: false,
LogSuspiciousOnly: false,
}
sm := NewSecurityMonitor(config, logger)
if sm == nil {
t.Fatal("Expected non-nil SecurityMonitor")
}
// Start cleanup routine first (lowercase method)
sm.startCleanupRoutine()
// Give it a moment to start
time.Sleep(50 * time.Millisecond)
// Stop cleanup routine (public method)
sm.StopCleanupRoutine()
// Stop again should be safe
sm.StopCleanupRoutine()
}
func TestSecurityMonitor_MultipleHandlers_CoverageBoost(t *testing.T) {
logger := NewLogger("info")
config := SecurityMonitorConfig{
MaxFailuresPerIP: 5,
FailureWindowMinutes: 15,
BlockDurationMinutes: 30,
RapidFailureThreshold: 3,
CleanupIntervalMinutes: 60,
RetentionHours: 24,
}
sm := NewSecurityMonitor(config, logger)
// Create handler
handler := &LoggingSecurityEventHandler{logger: logger}
// Register handler using AddEventHandler
sm.AddEventHandler(handler)
// Record a failure to trigger events
sm.RecordAuthenticationFailure("192.168.1.100", "test-agent", "/test", "test_failure", nil)
}
func TestLoggingSecurityEventHandler_HandleSecurityEvent_AllSeverities_CoverageBoost(t *testing.T) {
logger := NewLogger("debug")
handler := &LoggingSecurityEventHandler{logger: logger}
// Severity is a string in this implementation
events := []SecurityEvent{
{Type: "test", Severity: "low", Message: "low severity"},
{Type: "test", Severity: "medium", Message: "medium severity"},
{Type: "test", Severity: "high", Message: "high severity"},
{Type: "test", Severity: "critical", Message: "critical severity"},
}
for _, event := range events {
handler.HandleSecurityEvent(event)
}
}
// =============================================================================
// SESSION MANAGER TESTS
// =============================================================================
+4 -4
View File
@@ -31,7 +31,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
session.SetCSRF(csrfToken)
session.SetNonce("test-nonce")
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("old-access-token")
session.SetRefreshToken("old-refresh-token")
session.SetIDToken("old-id-token")
@@ -61,7 +61,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
// Now perform selective clearing (as done in the fix)
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
@@ -303,7 +303,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// Set initial session data
session.SetAuthenticated(true)
session.SetEmail("old@example.com")
session.SetUserIdentifier("old@example.com")
session.SetAccessToken("old-token")
session.SetCSRF("existing-csrf")
@@ -325,7 +325,7 @@ func TestRegressionLoginLoop(t *testing.T) {
// OLD BEHAVIOR: session.Clear() would have been called here, losing CSRF
// NEW BEHAVIOR: Selective clearing
session2.SetAuthenticated(false)
session2.SetEmail("")
session2.SetUserIdentifier("")
session2.SetAccessToken("")
session2.SetRefreshToken("")
session2.SetIDToken("")
+290
View File
@@ -0,0 +1,290 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"fmt"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/dcrstorage"
)
// DCRStorageBackend represents the type of storage backend for DCR credentials.
// Alias for internal package type for backward compatibility.
type DCRStorageBackend = dcrstorage.StorageBackend
const (
// DCRStorageBackendFile uses file-based storage (default for backward compatibility)
DCRStorageBackendFile DCRStorageBackend = dcrstorage.StorageBackendFile
// DCRStorageBackendRedis uses Redis for distributed storage
DCRStorageBackendRedis DCRStorageBackend = dcrstorage.StorageBackendRedis
// DCRStorageBackendAuto automatically selects Redis if available, otherwise file
DCRStorageBackendAuto DCRStorageBackend = dcrstorage.StorageBackendAuto
)
// DCRCredentialsStore defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type DCRCredentialsStore interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// loggerAdapter adapts our Logger to the dcrstorage.Logger interface
type loggerAdapter struct {
logger *Logger
}
func (l *loggerAdapter) Debug(msg string) { l.logger.Debug("%s", msg) }
func (l *loggerAdapter) Debugf(format string, args ...any) { l.logger.Debugf(format, args...) }
func (l *loggerAdapter) Info(msg string) { l.logger.Info("%s", msg) }
func (l *loggerAdapter) Infof(format string, args ...any) { l.logger.Infof(format, args...) }
func (l *loggerAdapter) Error(msg string) { l.logger.Error("%s", msg) }
func (l *loggerAdapter) Errorf(format string, args ...any) { l.logger.Errorf(format, args...) }
// cacheAdapter adapts UniversalCache to dcrstorage.Cache interface
type cacheAdapter struct {
cache *UniversalCache
}
func (c *cacheAdapter) Get(key string) (any, bool) {
return c.cache.Get(key)
}
func (c *cacheAdapter) Set(key string, value any, ttl time.Duration) error {
return c.cache.Set(key, value, ttl)
}
func (c *cacheAdapter) Delete(key string) {
c.cache.Delete(key)
}
// fileStoreWrapper wraps dcrstorage.FileStore to implement DCRCredentialsStore
type fileStoreWrapper struct {
inner *dcrstorage.FileStore
}
func (w *fileStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *fileStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *fileStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *fileStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// basePath returns the base path used for storing credentials (for backward compatibility in tests)
func (w *fileStoreWrapper) basePath() string {
return w.inner.BasePath()
}
// getFilePath returns the file path for storing credentials for a specific provider (for backward compatibility in tests)
func (w *fileStoreWrapper) getFilePath(providerURL string) string {
return w.inner.GetFilePath(providerURL)
}
// redisStoreWrapper wraps dcrstorage.RedisStore to implement DCRCredentialsStore
type redisStoreWrapper struct {
inner *dcrstorage.RedisStore
}
func (w *redisStoreWrapper) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
innerCreds := convertCredsToInternal(creds)
return w.inner.Save(ctx, providerURL, innerCreds)
}
func (w *redisStoreWrapper) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
innerCreds, err := w.inner.Load(ctx, providerURL)
if err != nil || innerCreds == nil {
return nil, err
}
return convertCredsFromInternal(innerCreds), nil
}
func (w *redisStoreWrapper) Delete(ctx context.Context, providerURL string) error {
return w.inner.Delete(ctx, providerURL)
}
func (w *redisStoreWrapper) Exists(ctx context.Context, providerURL string) (bool, error) {
return w.inner.Exists(ctx, providerURL)
}
// FileCredentialsStore implements DCRCredentialsStore using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
type FileCredentialsStore = fileStoreWrapper
// RedisCredentialsStore implements DCRCredentialsStore using Redis-backed cache.
// This storage backend enables sharing DCR credentials across multiple Traefik instances.
type RedisCredentialsStore = redisStoreWrapper
// NewFileCredentialsStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileCredentialsStore(basePath string, logger *Logger) *FileCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
inner := dcrstorage.NewFileStore(basePath, dcrLogger)
return &fileStoreWrapper{inner: inner}
}
// NewRedisCredentialsStore creates a new Redis-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisCredentialsStore(cache *UniversalCache, keyPrefix string, logger *Logger) *RedisCredentialsStore {
var dcrLogger dcrstorage.Logger
if logger != nil {
dcrLogger = &loggerAdapter{logger: logger}
}
cacheAdapt := &cacheAdapter{cache: cache}
inner := dcrstorage.NewRedisStore(cacheAdapt, keyPrefix, dcrLogger)
return &redisStoreWrapper{inner: inner}
}
// Helper functions to convert between main package and internal package types
func convertCredsToInternal(creds *ClientRegistrationResponse) *dcrstorage.ClientRegistrationResponse {
if creds == nil {
return nil
}
return &dcrstorage.ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
func convertCredsFromInternal(creds *dcrstorage.ClientRegistrationResponse) *ClientRegistrationResponse {
if creds == nil {
return nil
}
return &ClientRegistrationResponse{
SubjectType: creds.SubjectType,
LogoURI: creds.LogoURI,
RegistrationAccessToken: creds.RegistrationAccessToken,
RegistrationClientURI: creds.RegistrationClientURI,
Scope: creds.Scope,
TokenEndpointAuthMethod: creds.TokenEndpointAuthMethod,
TOSURI: creds.TOSURI,
PolicyURI: creds.PolicyURI,
ClientSecret: creds.ClientSecret,
ApplicationType: creds.ApplicationType,
ClientID: creds.ClientID,
ClientName: creds.ClientName,
JWKSURI: creds.JWKSURI,
ClientURI: creds.ClientURI,
Contacts: creds.Contacts,
GrantTypes: creds.GrantTypes,
ResponseTypes: creds.ResponseTypes,
RedirectURIs: creds.RedirectURIs,
ClientSecretExpiresAt: creds.ClientSecretExpiresAt,
ClientIDIssuedAt: creds.ClientIDIssuedAt,
}
}
// NewDCRCredentialsStore creates a DCRCredentialsStore based on configuration.
// This factory function handles backend selection logic:
// - "file": Use file-based storage (default for backward compatibility)
// - "redis": Use Redis exclusively (fails if Redis unavailable)
// - "auto": Use Redis if available, fallback to file
func NewDCRCredentialsStore(
config *DynamicClientRegistrationConfig,
cacheManager *CacheManager,
logger *Logger,
) (DCRCredentialsStore, error) {
if config == nil {
return nil, fmt.Errorf("DCR config is nil")
}
if logger == nil {
logger = GetSingletonNoOpLogger()
}
backend := config.StorageBackend
if backend == "" {
backend = string(DCRStorageBackendAuto) // Default to auto selection
}
switch DCRStorageBackend(backend) {
case DCRStorageBackendFile:
logger.Info("Using file-based storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
case DCRStorageBackendRedis:
cache := getDCRCache(cacheManager)
if cache == nil {
return nil, fmt.Errorf("redis storage requested but Redis/cache not configured")
}
logger.Info("Using Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
case DCRStorageBackendAuto:
// Try Redis first, fallback to file
cache := getDCRCache(cacheManager)
if cache != nil && cache.backend != nil {
logger.Info("Auto-selected Redis storage for DCR credentials")
return NewRedisCredentialsStore(cache, config.RedisKeyPrefix, logger), nil
}
logger.Info("Redis not available, using file storage for DCR credentials")
return NewFileCredentialsStore(config.CredentialsFile, logger), nil
default:
return nil, fmt.Errorf("unknown DCR storage backend: %s", backend)
}
}
// getDCRCache safely retrieves the DCR credentials cache from the cache manager
func getDCRCache(cacheManager *CacheManager) *UniversalCache {
if cacheManager == nil {
return nil
}
cacheManager.mu.RLock()
defer cacheManager.mu.RUnlock()
if cacheManager.manager == nil {
return nil
}
return cacheManager.manager.GetDCRCredentialsCache()
}
+663
View File
@@ -0,0 +1,663 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik
package traefikoidc
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// TestFileCredentialsStore_SaveLoad tests the file-based credentials store
func TestFileCredentialsStore_SaveLoad(t *testing.T) {
t.Parallel()
// Create a temp directory for test files
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
// Save credentials
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
// Load credentials
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
// Verify fields
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileCredentialsStore(filepath.Join(tempDir2, "nonexistent.json"), logger)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
// Should not error
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
// TestFileCredentialsStore_MultiProvider tests multi-provider support
func TestFileCredentialsStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
// Save credentials for both providers
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
// Load and verify each provider's credentials
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
// Delete one shouldn't affect the other
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
// TestFileCredentialsStore_ConcurrentAccess tests thread safety
func TestFileCredentialsStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
// Concurrent saves
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
// Concurrent loads
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
// Final verification
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
// TestFileCredentialsStore_InvalidInput tests error handling
func TestFileCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
// TestFileCredentialsStore_DefaultPath tests default path behavior
func TestFileCredentialsStore_DefaultPath(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore("", logger)
// Just verify we can create with empty path and it has a default
if store.basePath() == "" {
t.Error("Expected default base path")
}
}
// TestRedisCredentialsStore_WithMemoryCache tests Redis store with in-memory cache
func TestRedisCredentialsStore_WithMemoryCache(t *testing.T) {
t.Parallel()
// Create an in-memory cache for testing
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
// TestRedisCredentialsStore_TTLFromExpiry tests TTL calculation
func TestRedisCredentialsStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(), // Already expired
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0, // No expiry
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
// TestRedisCredentialsStore_InvalidInput tests error handling
func TestRedisCredentialsStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := NewUniversalCache(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 100,
DefaultTTL: time.Hour,
Logger: GetSingletonNoOpLogger(),
})
defer cache.Close()
logger := GetSingletonNoOpLogger()
store := NewRedisCredentialsStore(cache, "", logger)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
// TestDCRStorageFactory tests the factory function
func TestDCRStorageFactory(t *testing.T) {
t.Parallel()
logger := GetSingletonNoOpLogger()
t.Run("nil config returns error", func(t *testing.T) {
_, err := NewDCRCredentialsStore(nil, nil, logger)
if err == nil {
t.Error("Expected error for nil config")
}
})
t.Run("file backend creates file store", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "file",
CredentialsFile: "/tmp/test-creds.json",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create file store: %v", err)
}
if store == nil {
t.Error("Expected store but got nil")
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore")
}
})
t.Run("redis backend without cache manager returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "redis",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for redis backend without cache manager")
}
})
t.Run("auto backend without redis falls back to file", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "auto",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create auto store: %v", err)
}
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for auto without redis")
}
})
t.Run("unknown backend returns error", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "unknown",
}
_, err := NewDCRCredentialsStore(config, nil, logger)
if err == nil {
t.Error("Expected error for unknown backend")
}
})
t.Run("empty backend defaults to auto", func(t *testing.T) {
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
StorageBackend: "",
}
store, err := NewDCRCredentialsStore(config, nil, logger)
if err != nil {
t.Fatalf("Failed to create store with empty backend: %v", err)
}
// Should default to file (auto without redis)
_, ok := store.(*FileCredentialsStore)
if !ok {
t.Error("Expected FileCredentialsStore for empty backend")
}
})
}
// TestDynamicClientRegistrar_WithStore tests registrar with store
func TestDynamicClientRegistrar_WithStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil, // httpClient
logger,
config,
"https://auth.example.com",
store,
)
if registrar == nil {
t.Fatal("Expected registrar but got nil")
}
if registrar.store == nil {
t.Error("Expected store to be set")
}
// Test SetStore
newStore := NewFileCredentialsStore(filepath.Join(tempDir, "new.json"), logger)
registrar.SetStore(newStore)
if registrar.store != newStore {
t.Error("SetStore did not update the store")
}
}
// TestDynamicClientRegistrar_CredentialsFromStore tests loading from store
func TestDynamicClientRegistrar_CredentialsFromStore(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
providerURL := "https://auth.example.com"
ctx := context.Background()
// Pre-save credentials
testCreds := &ClientRegistrationResponse{
ClientID: "pre-saved-client",
ClientSecret: "pre-saved-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
}
if err := store.Save(ctx, providerURL, testCreds); err != nil {
t.Fatalf("Failed to pre-save credentials: %v", err)
}
config := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
}
registrar := NewDynamicClientRegistrarWithStore(
nil,
logger,
config,
providerURL,
store,
)
// Test loading via the internal method
loaded, err := registrar.loadCredentialsFromStore(ctx)
if err != nil {
t.Fatalf("Failed to load from store: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != "pre-saved-client" {
t.Errorf("ClientID mismatch: got %s", loaded.ClientID)
}
}
// TestFileCredentialsStore_CorruptedFile tests handling of corrupted files
func TestFileCredentialsStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(basePath, logger)
ctx := context.Background()
providerURL := "https://auth.example.com"
// Write corrupted JSON
filePath := store.getFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
// Should return error for corrupted file
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
// TestFileCredentialsStore_DirectoryCreation tests auto directory creation
func TestFileCredentialsStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
logger := GetSingletonNoOpLogger()
store := NewFileCredentialsStore(deepPath, logger)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+7 -4
View File
@@ -25,7 +25,10 @@ The **audience** (`aud`) claim in a JWT identifies the intended recipient of the
### Why Does This Matter?
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
Audience validation rejects access tokens whose `aud` claim does not match the
expected audience, blocking the trivial form of token confusion where a token
issued for API A is presented to API B. (Defence in depth — pair with
short-lived tokens, rotation, and per-API client credentials.)
---
@@ -137,8 +140,8 @@ http:
**Recommended:** `true` for production
**What it does:**
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
- When `true`: On audience mismatch, the middleware does **not** silently fall back to ID-token validation. It tries to refresh the access token first; if no refresh token is present (or refresh fails), the user is re-authenticated.
- When `false`: Logs warnings and falls back to ID-token validation (backward compatible).
**Example:**
```yaml
@@ -349,7 +352,7 @@ When opaque tokens are detected:
**Cache behavior:**
- Cache key: Token hash
- TTL: 5 minutes or token expiry (whichever is shorter)
- TTL: 5 minutes; if the token's `exp` is sooner, the cache entry expires at `exp` instead. Tokens without `exp` use the flat 5-minute TTL.
- Reduces introspection requests for frequently used tokens
---
+250
View File
@@ -0,0 +1,250 @@
# Bearer Token (M2M) Authentication
Opt-in path that lets API clients present `Authorization: Bearer <jwt>` to
authenticate without going through the cookie-based OIDC redirect flow.
Designed for machine-to-machine (M2M) traffic — services calling other
services with tokens minted by your OIDC provider.
The bearer path lives next to the cookie path: both go through the same
post-auth pipeline (`forwardAuthorized`) that injects identity headers,
checks `allowedRolesAndGroups`, applies security headers, and forwards to
the backend. The only thing that differs is how the principal is established
for that single request.
## Quick start
```yaml
enableBearerAuth: true
audience: https://api.example.com # REQUIRED when bearer is enabled
clientID: my-api-client-id
providerURL: https://issuer.example.com
sessionEncryptionKey: <32+-byte secret>
callbackURL: /oauth2/callback
```
That is the minimum. Everything else has a secure default.
## Obtaining bearer tokens from your OIDC provider
The middleware only **validates** bearer tokens — minting them is the IdP's job. For M2M traffic the canonical mint flow is OAuth 2.0 **`client_credentials`** (RFC 6749 §4.4); some providers require **JWT bearer assertion** (RFC 7523) instead.
```
┌────────────┐ POST /token ┌──────────┐
│ client │ ───────────────────────────────►│ IdP │
│ (service) │ grant_type=client_credentials │ /token │
│ │ client_id=… │ │
│ │ client_secret=… (or JWT) │ │
│ │ audience=https://api.… ←── critical │
│ │ scope=api:read … │
│ │ ◄───────────────────────────────│ │
│ │ access_token (JWT) │ │
└────────────┘ └──────────┘
│ GET /protected
│ Authorization: Bearer <access_token>
Your service (behind Traefik + this plugin)
```
The IdP returns a JWT signed by the same JWKs the middleware already trusts (it discovers them from `providerURL`/.well-known). On the first protected request, the middleware verifies signature + issuer + **audience** + `exp` + identifier claim, then forwards downstream with `X-Forwarded-User` set.
### Minimal worked example (Auth0-shape)
```bash
# 1. Mint a token
curl -s -X POST https://issuer.example.com/oauth/token \
-H 'Content-Type: application/json' \
-d '{
"grant_type": "client_credentials",
"client_id": "your-m2m-client-id",
"client_secret": "your-m2m-client-secret",
"audience": "https://api.example.com",
"scope": "api:read api:write"
}'
# → {"access_token":"eyJhbGciOiJSUzI1NiIs…","token_type":"Bearer","expires_in":86400,…}
# 2. Use it
curl -H 'Authorization: Bearer eyJhbGciOiJSUzI1NiIs…' https://api.example.com/protected
```
The `audience` field in the token request **must match** the `audience` you configured on the middleware. Mismatch → 401 with `Bearer error="invalid_token"`.
### Per-provider quick reference
| Provider | Grant | Token endpoint | Audience parameter | Notes |
|---|---|---|---|---|
| **Auth0** | `client_credentials` | `https://TENANT.auth0.com/oauth/token` | `audience=<your API identifier>` | Register an "API" + "Machine to Machine Application" authorised against that API. Without `audience` you get an opaque /userinfo token, which the bearer path rejects. See `docs/AUTH0_AUDIENCE_GUIDE.md`. |
| **Okta** | `client_credentials` | `https://TENANT.okta.com/oauth2/default/v1/token` | Configured in the authorization server; default `aud` is the auth-server URL | Service app must enable the `client_credentials` flow and be granted the requested scopes. |
| **Keycloak** | `client_credentials` | `https://kc/realms/REALM/protocol/openid-connect/token` | Configure an "Audience" mapper on a client scope, or use `client_id` as the audience | Client must have `serviceAccountsEnabled: true` plus role mappings. |
| **Entra ID / Azure AD** | `client_credentials` (v2.0 endpoint) | `https://login.microsoftonline.com/TENANT/oauth2/v2.0/token` | Pass `scope=<App ID URI>/.default`; `aud` ends up being the API's App ID URI | Requires an App Registration + API permissions + admin consent. **Use the v2.0 endpoint** — v1 issues Microsoft-proprietary access tokens that are opaque to non-Microsoft clients. |
| **AWS Cognito** | `client_credentials` | `https://YOUR_DOMAIN.auth.REGION.amazoncognito.com/oauth2/token` | Scopes from a "Resource Server" attached to your User Pool | App client must have `client_credentials` flow enabled. Use HTTP **Basic** auth header for `client_id:client_secret`. |
| **GitLab** | `client_credentials` | `https://gitlab.com/oauth/token` | Audience matches the GitLab issuer | Rarely used for protecting external APIs; better suited for GitLab's own resources. |
| **Google** | **JWT bearer (RFC 7523)***not* `client_credentials` | `https://oauth2.googleapis.com/token` | Signed assertion JWT carries `aud=https://oauth2.googleapis.com/token`; resulting access token is **opaque** unless you specifically request a Google-issued JWT for your API | Google service-account flow is not the best fit for this middleware (opaque tokens are rejected on the bearer path). Run Auth0 / Okta / Keycloak in front, or use ID-token-based flows on the cookie path. |
### RFC 7523 (JWT bearer assertion) — secretless alternative
When shared secrets are forbidden (FAPI, internal compliance), swap `client_secret` for a signed JWT assertion:
```
POST /token
grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer
assertion=<JWT signed by the client's private key>
```
The assertion JWT carries `iss=<client_id>`, `sub=<client_id>`, `aud=<token endpoint>`, `exp`. The IdP verifies the signature against a public key you've pre-registered and returns an access token.
This middleware already supports JWT assertions on the *middleware → IdP* hop via `clientAuthMethod: private_key_jwt` (see `docs/CONFIGURATION.md`). For the *client → IdP* hop, the same pattern applies — the client signs its own assertion.
### Operational notes
- **Token TTL is typically 124 hours.** Clients should refresh on `401`, not on a polling timer — saves the IdP.
- **Cache and reuse tokens.** The middleware caches verified tokens too, so repeated presentations are cheap. Clients SHOULD reuse a token until ~80 % of `expires_in`.
- **JWKS rotation is transparent.** The middleware auto-refreshes its JWKS cache when the IdP rotates keys. Clients don't need to do anything.
- **Revocation is generally not per-token** with `client_credentials`. If you need real-time revocation, set `requireTokenIntrospection: true` on the middleware and the IdP is consulted on every cache miss.
- **`scope` vs `audience`.** Scope says *what the client may do*; audience says *which service the token is for*. The middleware enforces audience; the backend service should enforce scope.
- **Secret hygiene.** Store `client_secret` in a secrets manager (Vault, AWS Secrets Manager, Kubernetes `Secret`). For higher assurance, switch the client to `private_key_jwt` (no shared secret at all).
### Quickest validation loop
```bash
# 1. Mint
TOKEN=$(curl -s -X POST https://issuer.example.com/oauth/token \
-H 'Content-Type: application/json' \
-d '{"grant_type":"client_credentials","client_id":"…","client_secret":"…","audience":"https://api.example.com"}' \
| jq -r .access_token)
# 2. Inspect claims to confirm aud/iss/exp match the middleware config
echo "$TOKEN" | cut -d. -f2 | base64 -d 2>/dev/null | jq
# 3. Hit the protected route
curl -i -H "Authorization: Bearer $TOKEN" https://api.example.com/protected
```
`HTTP/1.1 200` with `X-Forwarded-User` on the backend confirms the loop works end-to-end. `401` with `WWW-Authenticate: Bearer error="invalid_token"` plus a middleware debug log explaining the rejection (audience mismatch, ID token presented, `iat` outside the 24h window, etc.) confirms the hardening is firing as designed.
## Threat model and design rules
Bearer authentication has materially different security properties from
cookie sessions: no `HttpOnly`/`Secure`/`SameSite` shielding, the token is
visible in headers and logs, and it's easier to exfiltrate. The bearer path
treats every one of these as a first-class concern.
| Property | Behaviour | Why |
|---|---|---|
| Default state | `enableBearerAuth=false` | Bearer is opt-in; existing deployments observe no change. |
| Audience | **Mandatory.** Startup fails if `audience` is empty when bearer is enabled. | Eliminates the "token issued for service B accepted by service A" confusion attack. |
| Token format | JWT only (3 segments, JOSE-encoded). Opaque tokens are not accepted on the bearer path. | Matches the validation pipeline; opaque tokens require introspection only and bypass JWT-specific defences. |
| `alg` allowlist | Hard-pinned asymmetric: `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. Checked **before** any JWKS fetch. | Denies `alg=none` and `alg=HS*` probes; prevents attacker noise from amplifying into JWKS round-trips. |
| `kid` hardening | Max 256 bytes; charset `[A-Za-z0-9._\-=]`. Checked **before** JWKS fetch. | Prevents cache-key explosion / pathological-`kid` JWKS amplification. |
| Token type | ID tokens are explicitly rejected (`nonce` claim, `typ: at+jwt`, `token_use=id`, scope/aud heuristics — reuses the existing `detectTokenType` helper). | ID tokens are not API credentials; treating them as such is classic token confusion. |
| Multi-audience | When `aud` is an array of length > 1, the token must carry `azp == clientID`. | OIDC §2 hardening against tokens minted for one client being replayed by another. |
| `iat` upper-age | Rejects tokens older than `maxTokenAgeSeconds` (default 24h). | Bounds clock-manipulation / forever-token abuse, even if `exp` is far in the future. |
| Identifier claim | `bearerIdentifierClaim` (default `"sub"`). Resolved value drives `X-Forwarded-User`. | Decoupled from the cookie path's `UserIdentifierClaim` (default `email`) so the M2M flow can never accidentally trust an unverified email. |
| Identifier sanitisation | Length cap (`maxIdentifierLength`, default 256). Rejects control chars, Unicode bidi-overrides (U+202AU+202E, U+2066U+2069), and the delimiters `, ; =`. | Defence in depth against downstream header injection / log injection / admin-UI spoofing. |
| JTI replay marking | Bearer path skips the JTI **Set** (so the same token can be reused until `exp`) but the **Get** stays active. | Allows legitimate bearer reuse without false-positive replay detection; revoked tokens (added to the blacklist by `RevokeToken`) still fail immediately. |
| Mixed bearer + cookie | **Cookie wins by default.** Flip to bearer-wins with `bearerOverridesCookie=true`. | Safer against browser/extension/proxy bearer injection scenarios. The cookie is the authoritative authenticator when present. |
| `Authorization` strip | `stripAuthorizationHeader=true` by default. | Keeps the raw token out of downstream services and their logs. |
| Excluded URLs | `Authorization` is stripped on excluded paths when `enableBearerAuth=true`. | Prevents bearer leakage into public health/metrics endpoint logs and prevents recon via excluded paths. |
| Per-IP throttle | After `bearerFailureThreshold` consecutive 401s from one source IP within `bearerFailureWindowSeconds`, further bearer requests from that IP return `429 Too Many Requests` + `Retry-After` for `bearerFailurePenaltySeconds`. | Limits offline-guessing-style attacks and protects the shared rate-limiter / JWKS endpoint. |
| Optional introspection | `requireTokenIntrospection=true` calls RFC 7662 introspection on every cache miss. Introspection result is cached briefly. Endpoint failure returns `503` (distinguishes infra outage from credential rejection). | Real-time revocation for high-assurance environments. Adds per-request IdP latency. |
| Response shape | `401 Unauthorized` with generic body. `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 §3 (toggleable via `bearerEmitWWWAuthenticate`). `403` for roles/groups denial. `429` for throttle. `503` for introspection-endpoint outage. | Auditable from spec to code; reason categories never leak into the response body. |
| Logging | Failure reason + identifier hash (SHA-256 truncated to 8 hex chars) logged at debug. Raw tokens are never logged. | Audit trail without secrets-in-logs. |
## Configuration reference
| Field | Default | Description |
|---|---|---|
| `enableBearerAuth` | `false` | Master switch for the bearer path. |
| `audience` | (unset) | **Required** when `enableBearerAuth=true`. Reuses the existing global `audience` field. |
| `bearerIdentifierClaim` | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
| `stripAuthorizationHeader` | `true` | Remove the `Authorization` header before forwarding to the backend. Disable only when a downstream needs to re-verify the bearer. |
| `bearerEmitWWWAuthenticate` | `true` | Include `WWW-Authenticate: Bearer error="..."` on 401 responses (RFC 6750 §3). Disable to reduce recon signal. |
| `bearerOverridesCookie` | `false` | Cookie wins when both are present (default). Set `true` for the AWS/GCP/Kubernetes bearer-wins convention. |
| `maxTokenAgeSeconds` | `86400` | Upper bound on `iat` claim age (24h). Set `0` to disable the check (not recommended). |
| `maxIdentifierLength` | `256` | Length cap for the post-sanitisation identifier. |
| `bearerFailureThreshold` | `20` | Consecutive 401s from one IP that trip the throttle. |
| `bearerFailureWindowSeconds` | `60` | Rolling window over which 401s are counted. |
| `bearerFailurePenaltySeconds` | `60` | Duration of the 429 penalty box after the threshold trips. |
| `requireTokenIntrospection` | `false` | Call RFC 7662 introspection on every cache miss. Adds per-request IdP latency. |
## What the bearer path does NOT do
- **Human-user / browser flows.** The bearer path is M2M-only in this
iteration. Browser SPAs that want to attach a bearer to fetch calls work
if your backend treats them as machine clients, but the spec defaults are
tuned for service-to-service traffic.
- **Opaque access tokens.** Tokens must be JWTs. Introspection is a
revocation overlay on top of JWT verification, not a substitute for it.
- **`email_verified` enforcement.** The bearer path rejects `email` as the
identifier claim at startup precisely because `email_verified` is not
enforced in this iteration. Adding human-user bearer support is a
follow-up that must include this check.
- **mTLS / API keys.** Out of scope. The `principal` abstraction enables
adding these later as additional auth methods that produce a principal
for the shared `forwardAuthorized` pipeline.
- **SSE / WebSocket bypass with bearer.** Bypass paths keep their existing
cookie-only behaviour; bearer headers are ignored on those endpoints.
Documented limitation; widen by removing the bypass if you need bearer on
streaming endpoints.
## Operational guidance
- **Always set `strictAudienceValidation: true` when bearer is enabled.**
Startup logs a recommendation if you don't.
- **Set a tight `maxTokenAgeSeconds`** for environments where tokens are
expected to be minted frequently — the default 24h is conservative.
- **Enable `requireTokenIntrospection`** if your IdP supports it and
revocation latency matters. Bearer-path introspection caches results for
a short window per token.
- **Monitor 429s.** Sustained 429 traffic indicates either a buggy client
loop or an active credential-stuffing attempt. The throttle is your
primary signal for both.
- **`stripAuthorizationHeader=false` extends the token's blast radius** to
every downstream service that sees the request. Treat those services'
logs as token stores.
- **Bearer reuse is normal.** Don't enable per-token rate limiting; that's
what `bearerFailureThreshold` is for (per-IP, not per-token).
- **Cookie-wins is the safer default.** Only flip `bearerOverridesCookie`
if you control all clients and have audited that none of them present a
cookie alongside a bearer they don't intend to authenticate with.
## Failure response matrix
| Trigger | Status | Body | `WWW-Authenticate` |
|---|---|---|---|
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` |
| Token over `MaxLength` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Missing / oversized / bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Signature / issuer / audience / `exp` failure | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| `iat` older than `maxTokenAgeSeconds` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Multi-audience token without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Introspection reports `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) |
| Identifier claim missing / empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` |
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <bearerFailurePenaltySeconds>` |
| Roles / groups not allowed | 403 | `Access denied` | (none) |
## Known follow-ups (deferred)
These are documented as future work, not blockers:
- **Human-user bearer with `email_verified` enforcement.** Requires
decoupling the email-claim guard from the startup rejection and adding a
per-request `email_verified=true` check.
- **Introspection respects `client_assertion`.** The existing introspection
helper uses `client_secret_basic` only; operators on `private_key_jwt`
will see introspection silently use basic auth.
- **Per-route bearer configuration.** Single middleware-wide setting in this
iteration.
## References
- [PR design spec](superpowers/specs/2026-05-18-bearer-token-auth-design.md) — full design rationale, alternatives considered, and per-section sign-off history.
- [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) — Bearer Token Usage.
- [RFC 7662](https://www.rfc-editor.org/rfc/rfc7662) — OAuth 2.0 Token Introspection.
- [RFC 9068](https://www.rfc-editor.org/rfc/rfc9068) — JWT Profile for OAuth 2.0 Access Tokens.
+219 -9
View File
@@ -5,6 +5,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
## Table of Contents
- [Required Parameters](#required-parameters)
- [Client Authentication](#client-authentication)
- [Optional Parameters](#optional-parameters)
- [Security Options](#security-options)
- [Session Management](#session-management)
@@ -22,7 +23,7 @@ Complete reference for all Traefik OIDC middleware configuration options.
|-----------|------|-------------|---------|
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
| `clientSecret` | string | OAuth 2.0 client secret. Required when `clientAuthMethod` is unset, `client_secret_post`, or `client_secret_basic`. Optional when `clientAuthMethod: private_key_jwt`. | `your-client-secret` |
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
@@ -45,6 +46,129 @@ spec:
---
## Client Authentication
The middleware supports three client authentication methods at the token and
revocation endpoints. The default is `client_secret_post` (current behavior);
`private_key_jwt` is opt-in and backwards compatible.
| Method | Default | Description |
|--------|---------|-------------|
| `client_secret_post` | yes | `client_id` + `client_secret` in the request body. |
| `client_secret_basic` | no | RFC 6749 §2.3.1 — `client_id` + `client_secret` in the `Authorization: Basic` header (form-urlencoded then base64); not in the body. |
| `private_key_jwt` | no | RFC 7523 §2.2 — plugin signs a short-lived JWT with a private key and sends it as `client_assertion`. |
Select via `clientAuthMethod`:
```yaml
clientAuthMethod: private_key_jwt
```
### client_secret_post
Default. The plugin sends `client_id` and `client_secret` as form parameters
in the token / revocation request body. No additional configuration required.
### private_key_jwt
Asymmetric client authentication per
[RFC 7523 §2.2](https://www.rfc-editor.org/rfc/rfc7523). Use this when your
IdP enforces short secret TTLs, when policy mandates secretless clients, or
when you want to avoid distributing a shared secret to the proxy.
For each token / revocation request the plugin builds a JWS with:
- `iss` = `sub` = `clientID`
- `aud` = token endpoint URL
- `iat` = now, `exp` = now + 60s
- `jti` = random hex per request
- `kid` header = `clientAssertionKeyID`
**Required fields:**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `clientAuthMethod` | string | `client_secret_post` | Set to `private_key_jwt`. |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8, PKCS#1, and SEC1 formats accepted. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header inserted in the JWS. Must match the public key registered with the IdP. |
| `clientAssertionAlg` | string | `RS256` | One of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`, `ES256`, `ES384`, `ES512`. |
When `clientAuthMethod: private_key_jwt`, `clientSecret` is optional.
**Example — inline PEM:**
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-auth
spec:
plugin:
traefikoidc:
providerURL: https://idp.example.com
clientID: my-client-id
sessionEncryptionKey: your-32-byte-encryption-key-here
callbackURL: /oauth2/callback
clientAuthMethod: private_key_jwt
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
clientAssertionPrivateKey: |
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7VJTUt9Us8cKj
MZj4ev7QnMa1mYV3Kx1jRkH5YwXQ7N2J2j8K5pP6h0oZmXq1yQv4r8wZb3sH9D2k
... (truncated) ...
-----END PRIVATE KEY-----
```
**Example — key on disk:**
```yaml
clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc/client-key.pem
clientAssertionKeyID: key-2026-01
clientAssertionAlg: RS256
```
**Generating an RS256 key with OpenSSL:**
```bash
openssl genpkey -algorithm RSA -pkeyopt rsa_keygen_bits:2048 \
-out client-key.pem
openssl rsa -in client-key.pem -pubout -out client-pub.pem
```
Register `client-pub.pem` (or its JWK form) with your IdP under the same
`kid` you set in `clientAssertionKeyID`.
**Notes:**
- The private key is parsed once at plugin startup. Key rotation requires a
Traefik reload.
- Assertion lifetime is fixed at 60 seconds.
- A fresh random `jti` is generated per request.
- The `aud` claim is the token endpoint URL (from discovery).
- Tracking issue:
[#135](https://github.com/lukaszraczylo/traefikoidc/issues/135).
### client_secret_basic
Per [RFC 6749 §2.3.1][rfc6749-2-3-1], the plugin sends the client credentials
in an `Authorization: Basic` header instead of the body. Both halves
(`client_id`, `client_secret`) are form-urlencoded individually, joined with
a colon, then base64-encoded. Use this when your IdP requires Basic auth at
the token endpoint and rejects credentials in the body.
```yaml
clientAuthMethod: client_secret_basic
clientID: your-client-id
clientSecret: your-client-secret
```
[rfc6749-2-3-1]: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
---
## Optional Parameters
| Parameter | Type | Default | Description |
@@ -52,23 +176,55 @@ spec:
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
| `forceHTTPS` | bool | `true` | Force HTTPS for redirect URIs (set `false` only for plaintext HTTP local dev) |
| `rateLimit` | int | `100` | Maximum requests per second |
| `excludedURLs` | []string | none | Paths that bypass authentication |
| `excludedURLs` | []string | none | Paths that bypass authentication, matched at a path-segment or file-extension boundary |
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
| `clientAuthMethod` | string | `client_secret_post` | Client authentication method at token/revocation endpoints. One of `client_secret_post`, `client_secret_basic`, `private_key_jwt`. See [Client Authentication](#client-authentication). |
| `clientAssertionPrivateKey` | string | none | Inline PEM private key for `private_key_jwt`. Mutually exclusive with `clientAssertionKeyPath`. PKCS#8 / PKCS#1 / SEC1. |
| `clientAssertionKeyPath` | string | none | Path to PEM private key on disk for `private_key_jwt`. Mutually exclusive with `clientAssertionPrivateKey`. |
| `clientAssertionKeyID` | string | none | `kid` header for `private_key_jwt` assertions. Required when `clientAuthMethod: private_key_jwt`. |
| `clientAssertionAlg` | string | `RS256` | Signing algorithm for `private_key_jwt`. One of `RS256/384/512`, `PS256/384/512`, `ES256/384/512`. |
### TLS Termination at Load Balancer
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
`forceHTTPS` defaults to `true`, so redirect URIs always use `https://`. This is
the correct default behind any TLS-terminating load balancer (AWS ALB, Google
Cloud LB, Azure App Gateway) — `X-Forwarded-Proto` cannot be trusted (ALB may
overwrite it).
```yaml
forceHTTPS: true # Required for correct redirect URIs
```
Set `forceHTTPS: false` only when you serve OIDC over plaintext HTTP (local
dev). Otherwise leave it at default.
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
### Streaming Endpoints (SSE and WebSocket)
The middleware automatically bypasses the OIDC redirect for two request kinds
that browsers cannot follow a 302 on:
| Bypass | Triggered by |
|--------|--------------|
| Server-Sent Events (SSE) | `Accept: text/event-stream` |
| WebSocket upgrade | `Upgrade: websocket` + `Connection: upgrade` (RFC 6455) |
These requests do **not** require any explicit configuration — they are
handled implicitly. However, the bypass is **not** unauthenticated:
- A valid, encrypted session cookie is required. Requests without one are
rejected (the connection cannot proceed to the backend).
- The session cookie is sealed with `sessionEncryptionKey`, so the
`authenticated` flag cannot be forged.
- Validation is cookie-only — no JWK fetch / signature verification — so
streaming endpoints keep working when the OIDC provider is briefly
unavailable.
- The user identifier from the session is forwarded as `X-Forwarded-User`
(and `X-Auth-Request-User` unless `minimalHeaders: true`).
For browser clients, the user must complete the normal OIDC flow on a
regular HTTP page first; the resulting session cookie is then reused on the
SSE / WebSocket connection.
---
@@ -105,6 +261,26 @@ strictAudienceValidation: true
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
### Bearer-token (M2M) authentication
Opt-in path that accepts `Authorization: Bearer <jwt>` instead of the cookie
session flow. M2M-only, default off, audience-mandatory. See
[docs/BEARER_AUTH.md](BEARER_AUTH.md) for the threat model and operational
guidance.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enableBearerAuth` | bool | `false` | Master switch. Startup fails if true with empty `audience` or with `bearerIdentifierClaim=email`. |
| `bearerIdentifierClaim` | string | `"sub"` | JWT claim used as the principal identifier. `"email"` is rejected at startup. |
| `stripAuthorizationHeader` | bool | `true` | Strip `Authorization` from forwarded requests after successful bearer auth. |
| `bearerEmitWWWAuthenticate` | bool | `true` | Emit RFC 6750 `WWW-Authenticate: Bearer error="..."` hints on 401. |
| `bearerOverridesCookie` | bool | `false` | Cookie wins when both bearer and cookie are present (default). Set true for bearer-wins. |
| `maxTokenAgeSeconds` | int64 | `86400` | Upper bound on `iat` claim age (24h). 0 disables the check. |
| `maxIdentifierLength` | int | `256` | Length cap on the sanitised principal identifier. |
| `bearerFailureThreshold` | int | `20` | Consecutive 401s from one source IP that trip the throttle. |
| `bearerFailureWindowSeconds` | int | `60` | Rolling window for counting 401s. |
| `bearerFailurePenaltySeconds` | int | `60` | 429 + `Retry-After` duration after the threshold trips. |
---
## Session Management
@@ -113,6 +289,7 @@ strictAudienceValidation: true
|-----------|------|---------|-------------|
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
| `maxRefreshTokenAgeSeconds` | int | `21600` | Heuristic max age (in seconds) of a stored refresh token. Once exceeded, requests treat the RT as expired up front (returns 401 to AJAX, triggers full re-auth on navigations) instead of grant-spamming the IdP with `invalid_grant` retries. IdPs do not advertise RT TTL on the wire, so this is intentionally a conservative heuristic — tune to match your provider. Set `0` to disable. Default `21600` (6h). |
| `cookieDomain` | string | auto-detected | Domain for session cookies |
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
@@ -384,10 +561,14 @@ scopes:
### Dynamic Client Registration (RFC 7591)
Dynamic Client Registration allows the middleware to automatically register itself with the OIDC provider, eliminating the need to manually create client credentials.
**Basic Configuration (Single Instance):**
```yaml
dynamicClientRegistration:
enabled: true
initialAccessToken: "your-token" # Optional
initialAccessToken: "your-token" # Optional, if provider requires it
persistCredentials: true
credentialsFile: "/tmp/oidc-credentials.json"
clientMetadata:
@@ -400,6 +581,35 @@ dynamicClientRegistration:
- "refresh_token"
```
**Multi-Replica Deployment (Kubernetes):**
For Kubernetes deployments with multiple replicas, use Redis storage to share credentials across all instances and prevent registration race conditions:
```yaml
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: "redis" # Share credentials via Redis
redisKeyPrefix: "myapp:dcr:" # Optional custom prefix
clientMetadata:
redirect_uris:
- "https://your-app.com/oauth2/callback"
client_name: "My Application"
redis:
enabled: true
address: "redis:6379"
cacheMode: "redis"
```
**Storage Backend Options:**
| Backend | Description | Use Case |
|---------|-------------|----------|
| `file` | Store credentials in local file | Single instance deployments |
| `redis` | Store credentials in Redis | Multi-replica Kubernetes deployments |
| `auto` | Use Redis if available, fallback to file | Flexible deployments (default) |
### Multi-Replica Deployment
Without Redis, disable replay detection:
+95
View File
@@ -0,0 +1,95 @@
# Dynamic Client Registration (RFC 7591)
The middleware can register itself with an OIDC provider at startup instead of
using a pre-provisioned `clientID` / `clientSecret`. Useful for multi-tenant
deployments, self-service integrations, and ephemeral environments.
## How it works
1. Middleware reads `registration_endpoint` from `.well-known/openid-configuration`.
2. If `clientID` is empty, it `POST`s `clientMetadata` to the registration endpoint.
3. Returned `client_id` / `client_secret` are cached, optionally persisted.
4. Subsequent requests use the registered credentials.
For multi-replica deployments, set `storageBackend: redis` so all replicas
share one client and avoid registration races.
## Configuration
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-dcr
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: https://your-oidc-provider.com
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
callbackURL: /oauth2/callback
dynamicClientRegistration:
enabled: true
persistCredentials: true
storageBackend: redis # file | redis | auto
initialAccessToken: "" # optional, for protected endpoints
registrationEndpoint: "" # optional, override discovery
credentialsFile: /tmp/oidc-client-credentials.json
redisKeyPrefix: "dcr:creds:"
clientMetadata:
redirect_uris:
- https://app.example.com/oauth2/callback
client_name: My Application
application_type: web
grant_types: [authorization_code, refresh_token]
response_types: [code]
token_endpoint_auth_method: client_secret_basic
contacts: [admin@example.com]
```
## Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `enabled` | `false` | Enable DCR. |
| `persistCredentials` | `false` | Save returned credentials for reuse across restarts. |
| `storageBackend` | `auto` | `file`, `redis`, or `auto` (Redis if available, else file). |
| `credentialsFile` | `/tmp/oidc-client-credentials.json` | Path for file-backed storage. Mode `0600`. |
| `redisKeyPrefix` | (none — set explicitly) | Key prefix for Redis-backed storage. The code does not inject a default; if unset, keys have no prefix. `dcr:creds:` is a sensible convention. |
| `registrationEndpoint` | discovered | Override the discovered endpoint. |
| `initialAccessToken` | none | Bearer token for protected registration endpoints. |
| `clientMetadata.redirect_uris` | required | Callback URIs for the OAuth flow. |
| `clientMetadata.client_name` | none | Human-readable client name. |
| `clientMetadata.application_type` | `web` | `web` or `native`. |
| `clientMetadata.grant_types` | `[authorization_code, refresh_token]` | OAuth grant types. |
| `clientMetadata.response_types` | `[code]` | OAuth response types. |
| `clientMetadata.token_endpoint_auth_method` | `client_secret_basic` | `client_secret_basic`, `client_secret_post`, or `none`. |
| `clientMetadata.scope` | none | Space-separated scopes. |
| `clientMetadata.contacts` | none | Admin email addresses. |
| `clientMetadata.logo_uri` | none | Logo URL for consent screens. |
| `clientMetadata.client_uri` | none | Client homepage URL. |
| `clientMetadata.policy_uri` | none | Privacy policy URL. |
| `clientMetadata.tos_uri` | none | Terms of service URL. |
## Provider support
The middleware does not gate DCR by provider — if the provider exposes a
`registration_endpoint` in its discovery document (or you set
`registrationEndpoint` explicitly), DCR will attempt registration. The table
below is informational guidance based on each provider's published support.
| Provider | DCR | Notes |
|----------|-----|-------|
| Keycloak | Yes | Enable in realm settings. |
| Auth0 | Yes | Requires Management API token. |
| Okta | Yes | Enable Dynamic Client Registration in admin console. |
| Azure AD | Limited | Use App Registration API instead. |
| Google | No | Manual registration required. |
| AWS Cognito | No | Manual registration required. |
## Security notes
- Registration endpoints must be HTTPS (loopback excepted for local dev).
- Use `initialAccessToken` in production to gate registration.
- File-backed credentials use `0600`; protect the mount path.
- The plugin marks credentials invalid when within ~5 min of `client_secret_expires_at` but does **not** automatically re-register. If your provider sets a non-zero expiry, schedule manual rotation (delete the credentials file or Redis entry, restart) before that time.
+20 -99
View File
@@ -16,9 +16,8 @@ Guide for local development, testing, and contributing to the Traefik OIDC middl
## Prerequisites
- **Go 1.23+** for plugin compilation
- **Docker & Docker Compose** for local testing
- **OIDC Provider** credentials (Google, Azure, etc.)
- **Go 1.24+** (matches `go.mod`; CI runs Go 1.24.11)
- **OIDC Provider** credentials (Google, Azure, etc.) for any end-to-end test against a real provider
### Required Development Tools
@@ -40,110 +39,32 @@ go install golang.org/x/vuln/cmd/govulncheck@latest
## Local Development Setup
### Docker Compose Environment
The repository includes a Docker Compose setup for testing the plugin locally.
#### 1. Host Configuration
Add to `/etc/hosts`:
### Build and unit tests
```bash
127.0.0.1 hello.localhost
127.0.0.1 traefik.localhost
go mod tidy
go build ./...
go test ./... -short # fast loop, < 30 s
go test -race -timeout=15m ./...
```
#### 2. Plugin Configuration
### Sample plugin configurations
The plugin is loaded using Traefik's **local plugins mode**:
Working middleware/Traefik configs live in [`examples/`](../examples/):
- Plugin source: Parent directory (`../`)
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
- Configuration: `experimental.localPlugins` in `traefik.yml`
- `complete-traefik-config.yaml` — full middleware example
- `redis-config.yaml` — Redis cache configuration
#### 3. OIDC Provider Setup
To run the plugin against a real Traefik instance, drop the project on disk
and load it via `experimental.localPlugins` in your Traefik static config —
see the [README install section](../README.md#install).
Edit `docker/dynamic.yml` with your provider details:
### Integration tests
**Google:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://accounts.google.com"
clientID: "your-client-id.apps.googleusercontent.com"
clientSecret: "your-google-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
logoutURL: "/oauth2/logout"
scopes:
- "openid"
- "email"
- "profile"
```
**Azure AD:**
```yaml
http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
clientID: "your-azure-client-id"
clientSecret: "your-azure-client-secret"
sessionEncryptionKey: "your-32-character-encryption-key"
callbackURL: "/oauth2/callback"
scopes:
- "openid"
- "email"
- "profile"
```
#### 4. Start Environment
Integration tests live in `integration/`. Run them explicitly:
```bash
cd docker
docker-compose up -d
```
#### 5. Test Plugin
- **Protected App**: http://hello.localhost (redirects to OIDC)
- **Traefik Dashboard**: http://traefik.localhost:8080
### Development Workflow
1. **Edit plugin code** in the project root
2. **Build and test** (optional syntax check):
```bash
go mod tidy
go build .
go test ./...
```
3. **Restart Traefik** to reload plugin:
```bash
docker-compose restart traefik
```
4. **Test changes** at http://hello.localhost
### Debugging
**View plugin logs:**
```bash
docker-compose logs -f traefik | grep traefikoidc
```
**Check plugin loading:**
```bash
docker-compose logs traefik | grep -i plugin
```
**Verify plugin directory:**
```bash
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
go test ./integration/... -run Integration -v
```
---
@@ -299,7 +220,7 @@ The repository uses GitHub Actions for comprehensive validation with 20+ paralle
#### Testing (9 suites)
- Race Detector
- Coverage (75% threshold)
- Coverage (70% threshold, enforced in `pr.yaml`)
- Memory Leaks
- Integration Tests
- Regression Tests
@@ -323,13 +244,13 @@ Tests run in parallel for:
#### Performance & Build (3 checks)
- Benchmarks
- Multi-platform Build (linux/darwin x amd64/arm64)
- Go Version Compatibility (Go 1.23 & 1.24)
- Go Version Compatibility (currently Go 1.24.11 in CI)
### Quality Gates
All PRs must pass:
- All parallel checks
- 75% test coverage minimum
- 70% test coverage minimum
- Zero security vulnerabilities
- No race conditions
- No memory leaks
+3 -3
View File
@@ -23,10 +23,10 @@ Configuration reference for each supported OIDC provider.
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|----------|-------------|----------------|----------------|-----------|
| Google | Full | Yes | `accounts.google.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
| Azure AD | Full | Yes | `login.microsoftonline.com`, `sts.windows.net` | Yes |
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
| Okta | Full | Yes | `*.okta.com` | Yes |
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
| Okta | Full | Yes | `*.okta.com`, `*.oktapreview.com`, `*.okta-emea.com` | Yes |
| Keycloak | Full | Yes | host containing `keycloak`, or `/realms/` in path (matches both `/auth/realms/` legacy and `/realms/` modern) | Yes |
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
| GitLab | Full | Yes | `gitlab.com` | Yes |
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
+14 -6
View File
@@ -109,11 +109,11 @@ redis:
| `writeTimeout` | int | `3` | Write timeout (seconds) |
| `enableTLS` | bool | `false` | Enable TLS for connections |
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
| `enableCircuitBreaker` | bool | `false` | Wrap the Redis backend with a circuit breaker. **Recommended `true` in production.** |
| `circuitBreakerThreshold` | int | `5` | Consecutive failures before the circuit opens (only when `enableCircuitBreaker: true`). |
| `circuitBreakerTimeout` | int | `60` | Seconds the circuit stays open before allowing a probe (only when `enableCircuitBreaker: true`). |
| `enableHealthCheck` | bool | `false` | Wrap the Redis backend with periodic health checks. **Recommended `true` in production.** |
| `healthCheckInterval` | int | `30` | Health check interval in seconds (only when `enableHealthCheck: true`). |
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
@@ -134,13 +134,21 @@ REDIS_READ_TIMEOUT=3
REDIS_WRITE_TIMEOUT=3
REDIS_ENABLE_TLS=false
REDIS_TLS_SKIP_VERIFY=false
REDIS_HYBRID_L1_SIZE=500
REDIS_HYBRID_L1_MEMORY_MB=10
```
> Resilience fields (`enableCircuitBreaker`, `enableHealthCheck`,
> `circuitBreakerThreshold`, `circuitBreakerTimeout`, `healthCheckInterval`)
> have no environment variable fallback — set them in plugin configuration.
Invalid `cacheMode` values are rejected at plugin startup.
---
## Cache Modes
### Memory Mode (Default without Redis)
### Memory Mode (used when Redis is disabled)
```yaml
redis:
+2 -2
View File
@@ -6,8 +6,8 @@ Comprehensive testing infrastructure for traefikoidc.
| Metric | Value |
|--------|-------|
| Test files | 99 |
| Lines of test code | ~65,500 |
| Test files | 110 |
| Lines of test code | ~72,000 |
| Code coverage | 71.0% |
| Race conditions | None (all pass with `-race`) |
+156 -4
View File
@@ -90,6 +90,7 @@
<a href="#configuration" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Configuration</a>
<a href="#deployment" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Deployment</a>
<a href="#security" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Security</a>
<a href="#logout" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 font-medium">Logout</a>
</div>
<div class="flex items-center space-x-4">
<button id="theme-toggle" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 p-2 min-w-[44px] min-h-[44px] flex items-center justify-center" aria-label="Toggle theme">
@@ -114,6 +115,7 @@
<a href="#configuration" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Configuration</a>
<a href="#deployment" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Deployment</a>
<a href="#security" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Security</a>
<a href="#logout" class="block px-3 py-3 text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 hover:bg-gray-50 dark:hover:bg-gray-700 rounded font-medium">Logout</a>
</div>
</div>
</nav>
@@ -193,7 +195,7 @@
</div>
<div>
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-1">Dynamic Registration</h3>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration for automatic client setup without manual configuration</p>
<p class="text-sm text-gray-600 dark:text-gray-400">RFC 7591 Dynamic Client Registration with Redis storage support for multi-replica deployments</p>
</div>
</div>
</div>
@@ -640,7 +642,7 @@ spec:
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientSecret</code></td>
<td class="py-2 px-3">OAuth 2.0 client secret</td>
<td class="py-2 px-3">OAuth 2.0 client secret. Only required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is unset or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">sessionEncryptionKey</code></td>
@@ -716,6 +718,11 @@ spec:
<td class="py-2 px-3">86400</td>
<td class="py-2 px-3">Maximum session age in seconds (24 hours default)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">maxRefreshTokenAgeSeconds</code></td>
<td class="py-2 px-3">21600</td>
<td class="py-2 px-3">Heuristic upper bound on stored refresh-token lifetime (6 hours default). Past this, the plugin treats the RT as expired without contacting the IdP. Set <code>0</code> to disable.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">cookiePrefix</code></td>
<td class="py-2 px-3">_oidc_raczylo_</td>
@@ -746,15 +753,48 @@ spec:
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Require RFC 7662 introspection for opaque tokens</td>
</tr>
<tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">disableReplayDetection</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Disable JTI replay detection (for multi-replica without Redis)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code></td>
<td class="py-2 px-3">client_secret_post</td>
<td class="py-2 px-3">Selects how the plugin authenticates to the token endpoint. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_post</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret_basic</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">Inline PEM private key used to sign client assertions for <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyPath</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">Path to a PEM private key file. Alternative to <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionPrivateKey</code>.</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionKeyID</code></td>
<td class="py-2 px-3">none</td>
<td class="py-2 px-3">JWS <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">kid</code> header value. Required when <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAuthMethod</code> is <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">private_key_jwt</code>.</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">clientAssertionAlg</code></td>
<td class="py-2 px-3">RS256</td>
<td class="py-2 px-3">Signing algorithm for the client assertion. One of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">RS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">PS512</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES256</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES384</code>/<code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">ES512</code>.</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Private Key JWT (RFC 7523)</h3>
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Use this when your IdP (Entra ID, Okta, Auth0, Keycloak) pressures short-lived secrets, or when policy mandates secretless service-to-service authentication. The plugin signs a 60-second assertion with the configured private key and sends it as <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_assertion</code> instead of <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">client_secret</code>. Public-key registration on the IdP replaces shared-secret rotation. See <a href="https://www.rfc-editor.org/rfc/rfc7523" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">RFC 7523</a> and <a href="https://github.com/lukaszraczylo/traefikoidc/issues/135" target="_blank" rel="noopener" class="text-gray-600 dark:text-gray-300 hover:text-gray-900 dark:hover:text-gray-100 underline">issue #135</a>.</p>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>clientAuthMethod: private_key_jwt
clientAssertionKeyPath: /etc/traefik/oidc-client.pem
clientAssertionKeyID: my-client-key-2026
# clientSecret no longer required</code></pre>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-3">Example: Google Workspace with Domain Restriction</h3>
@@ -856,7 +896,54 @@ spec:
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.enableTLS</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable TLS for Redis connections</td>
<td class="py-2 px-3">Enable TLS for Redis connections (e.g. AWS ElastiCache in-transit encryption)</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis.tlsSkipVerify</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Skip TLS server certificate verification (testing only; not recommended in production)</td>
</tr>
</tbody>
</table>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Dynamic Client Registration (RFC 7591)</h3>
<p class="text-gray-600 dark:text-gray-400 mb-3 text-sm">Automatically register your application with the OIDC provider. Supports Redis storage for multi-replica deployments:</p>
<div class="overflow-x-auto mb-4">
<table class="w-full text-sm">
<thead>
<tr class="border-b border-gray-200 dark:border-gray-700">
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Parameter</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Default</th>
<th class="text-left py-2 px-3 text-gray-900 dark:text-gray-100">Description</th>
</tr>
</thead>
<tbody class="text-gray-600 dark:text-gray-400">
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.enabled</code></td>
<td class="py-2 px-3">false</td>
<td class="py-2 px-3">Enable dynamic client registration</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.persistCredentials</code></td>
<td class="py-2 px-3">true</td>
<td class="py-2 px-3">Persist registered credentials across restarts</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.storageBackend</code></td>
<td class="py-2 px-3">auto</td>
<td class="py-2 px-3">Storage backend: <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">file</code>, <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">redis</code>, or <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">auto</code> (uses Redis if available)</td>
</tr>
<tr class="border-b border-gray-100 dark:border-gray-800">
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.redisKeyPrefix</code></td>
<td class="py-2 px-3">dcr:creds:</td>
<td class="py-2 px-3">Redis key prefix for DCR credentials</td>
</tr>
<tr>
<td class="py-2 px-3"><code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">dynamicClientRegistration.clientMetadata.redirect_uris</code></td>
<td class="py-2 px-3">-</td>
<td class="py-2 px-3">Redirect URIs for the registered client (required)</td>
</tr>
</tbody>
</table>
@@ -1177,6 +1264,71 @@ spec:
</div>
</section>
<!-- IdP-Initiated Logout Section -->
<section id="logout" class="py-12 sm:py-16 md:py-20 bg-white dark:bg-gray-900 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
<div class="text-center mb-8 sm:mb-12">
<h2 class="text-2xl sm:text-3xl md:text-4xl font-bold text-gray-900 dark:text-gray-100 mb-3 sm:mb-4">IdP-Initiated Logout</h2>
<p class="text-base sm:text-lg text-gray-600 dark:text-gray-300 px-4">Support for OIDC Back-Channel and Front-Channel Logout specifications</p>
</div>
<div class="max-w-4xl mx-auto">
<div class="grid md:grid-cols-2 gap-6 mb-8">
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-server mr-2 text-blue-500"></i>
Back-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Server-to-server logout notification. The IdP sends a signed JWT (logout_token) directly to your application when a user logs out.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Signed JWT logout tokens</li>
<li>&#8226; Session ID (sid) based invalidation</li>
<li>&#8226; Subject (sub) based invalidation</li>
<li>&#8226; Works behind firewalls</li>
</ul>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4 flex items-center">
<i class="fas fa-browser mr-2 text-purple-500"></i>
Front-Channel Logout
</h3>
<p class="text-gray-600 dark:text-gray-400 text-sm mb-4">
Browser-based logout via iframe. The IdP embeds an iframe pointing to your logout endpoint during user logout.
</p>
<ul class="text-gray-600 dark:text-gray-400 space-y-2 text-sm">
<li>&#8226; Iframe-based session termination</li>
<li>&#8226; Immediate cookie invalidation</li>
<li>&#8226; Simple GET request handling</li>
<li>&#8226; Issuer validation</li>
</ul>
</div>
</div>
<div class="glass p-6 rounded-xl">
<h3 class="font-semibold text-gray-900 dark:text-gray-100 mb-4">Configuration Example</h3>
<pre class="bg-gray-900 text-gray-100 p-4 rounded-lg overflow-x-auto text-sm"><code>http:
middlewares:
oidc-auth:
plugin:
traefikoidc:
# ... other OIDC configuration ...
# Back-Channel Logout (server-to-server)
enableBackchannelLogout: true
backchannelLogoutURL: "/backchannel-logout"
# Front-Channel Logout (browser-based)
enableFrontchannelLogout: true
frontchannelLogoutURL: "/frontchannel-logout"</code></pre>
<p class="text-gray-600 dark:text-gray-400 text-sm mt-4">
Configure your IdP with the full URLs (e.g., <code class="bg-gray-200 dark:bg-gray-700 px-1 rounded">https://your-app.example.com/backchannel-logout</code>).
When a user logs out from the IdP, all their sessions across your applications will be invalidated.
</p>
</div>
</div>
</div>
</section>
<!-- Why Choose Section -->
<section class="py-12 sm:py-16 md:py-20 bg-gray-50 dark:bg-gray-800 theme-transition">
<div class="max-w-6xl mx-auto px-4 sm:px-6">
@@ -0,0 +1,459 @@
# Bearer Token Authentication — Design Spec
- **Date**: 2026-05-18
- **Status**: Design — pending implementation plan
- **Supersedes**: PR #93 (broken implementation; recommended to close in favour of this design)
## 1. Summary
Add an opt-in path that lets API clients (machine-to-machine) authenticate by presenting a signed access token in the `Authorization: Bearer <token>` header, bypassing the cookie-based OIDC redirect flow. Identity, roles, and authorization checks remain consistent with the existing cookie path; the only thing that changes is how the principal is established for that single request.
The feature is implemented by extracting a shared `forwardAuthorized` pipeline from the existing `processAuthorizedRequest`, introducing a `principal` value type, and adding a small bearer-specific entrypoint that builds a principal directly from a verified JWT — without synthesising a fake `SessionData`.
## 2. Motivation
PR #93 attempted this feature by building an in-memory `SessionData` from JWT claims and reusing `processAuthorizedRequest`. The approach has three latent defects:
1. The synthetic session omits `mainSession.Values["user_identifier"]`. `processAuthorizedRequest` reads it via `GetUserIdentifier()`; when empty it bails to `defaultInitiateAuthentication` and issues an OIDC redirect. The feature is non-functional in practice despite the unit test passing.
2. `verifyToken` accepts both ID tokens (audience match against `clientID`) and access tokens. ID tokens are not API credentials; treating them as such is a classic token-confusion vector.
3. `verifyToken` adds JTI to the replay blacklist on first verify. Once the verified-token cache evicts, subsequent reuse of the same bearer token triggers a false-positive replay rejection.
Rather than patch a synthetic-session approach that will keep generating bugs as `SessionData` evolves, this spec replaces it with a cleaner abstraction where session lifecycle and post-auth header injection live in separate units.
## 3. Goals
- Accept `Authorization: Bearer <jwt>` from M2M clients, validate the token, and forward the request downstream with identity headers populated.
- Enforce the same `allowedRolesAndGroups` policy as the cookie path.
- Default-off; safe defaults when enabled (audience required, ID tokens rejected, identifier sanitised).
- No behavioural change to the cookie path. Existing tests must continue to pass without modification.
## 4. Non-Goals
- Human-user / browser flows. Bearer is M2M-only in this iteration.
- Pure opaque access tokens on the bearer path. Tokens must be JWTs; introspection (RFC 7662) is supported *on top of* JWT verification for revocation state, not as a substitute for it.
- mTLS, API keys, or any other auth method. The `principal` abstraction enables them later, but they are not delivered here.
- Per-route bearer configuration. Single middleware-wide setting.
## 5. Decided Requirements
| Topic | Decision |
|---|---|
| Consumer type | Machine-to-machine (M2M) only |
| Token format | JWT only (signature, issuer, audience, exp) |
| Audience | Mandatory when feature enabled; startup fails if `Audience == ""` |
| Token type | Access tokens only; ID tokens explicitly rejected |
| Revocation | JWT-only verification by default; introspection (RFC 7662) opt-in via existing `RequireTokenIntrospection` |
| Identity claim | New `BearerIdentifierClaim` config (string, default `"sub"`). Bearer path reads this claim exclusively; does NOT use `UserIdentifierClaim` (which defaults to `"email"` and drives the cookie path). Resolved value must be a non-empty string. `sub` is mandatory per `jwt.go:416` regardless, so even with a different `BearerIdentifierClaim` the token must still carry a valid `sub`. Decoupling avoids the M2M-vs-human-user identity-claim conflict and the email-spoofing footgun. |
| Identifier sanitisation | Reject value containing any `unicode.IsControl` char, any Unicode bidi-override (U+202AU+202E, U+2066U+2069), leading/trailing whitespace, commas, semicolons, equals signs. Max length 256 bytes. |
| Token classifier | **Reuse existing `detectTokenType(jwt, token)` at `token_manager.go:187-303`** which already handles `nonce`, `typ: at+jwt`, `token_use`, `scope`, and aud-vs-clientID priority. Bearer path rejects any token where `detectTokenType == true` (ID token). Do not invent a parallel classifier. |
| Algorithm pinning | Hard-pin `alg ∈ {RS256, RS384, RS512, PS256, PS384, PS512, ES256, ES384, ES512}`, enforced **before** JWKS lookup on the bearer path. Prevents wasted JWKS fetches for `alg=none`/HS attacker probes. |
| `kid` hardening | `kid` ≤ 256 bytes, charset `[A-Za-z0-9._\-=]`. Reject before JWKS lookup. |
| Token age | Bearer path enforces `now - iat <= MaxTokenAgeSeconds` (default 86400 / 24h, configurable). Cookie path unchanged. |
| Multi-audience policy | If `aud` is an array (length > 1), require `azp` claim to be present and equal to `clientID`. Single-string `aud` unaffected. |
| Mixed bearer + cookie precedence | **Cookie wins by default** when both are presented (safer for browser scenarios). Operator opt-in: `BearerOverridesCookie=true` to flip. Either way, a warning is logged on the request. |
| Bearer + excluded URL | `Authorization` header is **stripped** before forwarding when the request hits an excluded URL. Prevents bearer leaking into public endpoints' downstream logs and prevents recon via excluded paths. |
| Per-source bearer 401 throttle | New sharded cache `failedBearerAttempts` keyed by client IP. After N (default 20) consecutive 401s from one IP within 1 minute, reject further bearer requests from that IP with 429 for 60s. Applied BEFORE `verifyToken` to deny JWKS amplification. |
| `Authorization` header passthrough | New `StripAuthorizationHeader` config, default `true` |
| Roles/groups gating | Same `allowedRolesAndGroups` rules as cookie path |
| Default state | `EnableBearerAuth` = `false` |
| JTI replay marking | Suppressed on bearer path; cookie path unchanged |
| Failure response shape | 401 with generic body; `WWW-Authenticate: Bearer error="invalid_token"` per RFC 6750 |
| Introspection endpoint outage | 503 (distinguishes infra outage from token rejection) |
| Mixed bearer + cookie | Bearer wins; cookie ignored on that request |
| SSE/WS bypass + bearer | Bypass paths keep cookie-only check; bearer header ignored on SSE/WS |
## 6. Architecture
```
┌──────────────────┐
HTTP req ──► │ ServeHTTP │ (existing entry; adds bearer detection)
└─────────┬────────┘
┌───────────┴────────────┐
▼ ▼
cookie / session bearer (Authorization: Bearer …)
│ │
▼ ▼
┌────────────────┐ ┌────────────────────┐
│ buildPrincipal │ │ buildPrincipal │
│ FromSession() │ │ FromBearerToken() │
└────────┬───────┘ └─────────┬──────────┘
│ produces *principal │
└──────────────┬───────────┘
┌────────────────────────────┐
│ forwardAuthorized(rw,req,p)│ (shared pipeline)
│ • roles/groups gate │
│ • header injection │
│ • header templates │
│ • security headers │
│ • cookie stripping │
│ • next.ServeHTTP │
└────────────────────────────┘
```
**Invariant**: `forwardAuthorized` never touches session storage. Session-specific concerns (Save, IsDirty, backchannel-logout invalidation) stay inside `processAuthorizedRequest` around the call to `forwardAuthorized`.
**Feature gate**: when `EnableBearerAuth == false`, the bearer-detection check in `ServeHTTP` is a no-op. Existing deployments observe byte-identical behaviour.
## 7. Components
### 7.1 `principal` type (new file `principal.go`)
```go
type principalSource int
const (
sourceSession principalSource = iota
sourceBearer
)
type principal struct {
Identifier string // drives X-Forwarded-User
Email string // optional, "" for M2M
Subject string // sub claim
ClientID string // azp / client_id, M2M caller
Claims map[string]interface{} // raw claims for templates / groups
AccessToken string // for X-Auth-Request-Token (gated by minimalHeaders)
IDToken string // "" on bearer path
RefreshToken string // "" on bearer path
Source principalSource
}
```
Pure data. No methods that mutate it. No I/O. No manager pointer.
### 7.2 `buildPrincipalFromSession(*SessionData) *principal` (new in `principal.go`)
Read-only adapter over existing `SessionData` getters: `GetUserIdentifier`, `GetEmail`, `GetAccessToken`, `GetIDToken`, `GetRefreshToken`, cached claims via `GetIDTokenClaims`. Does not write back to the session. This is the only function that still knows about `SessionData`.
### 7.3 `buildPrincipalFromBearerToken(token string) (*principal, error)` (new in `bearer_auth.go`)
1. **Length / format guards**: `len(token) <= AccessTokenConfig.MaxLength`, exactly two dots, non-empty after trim.
2. **Parse header for early alg/kid pinning** (without trusting payload): decode JOSE header; reject if `alg` ∉ asymmetric allowlist; reject if `kid` missing, > 256 bytes, or contains chars outside `[A-Za-z0-9._\-=]`. This happens **before** JWKS lookup so attacker noise doesn't amplify into JWKS fetches.
3. **Per-IP 401 throttle check**: if this IP is in the `failedBearerAttempts` penalty box, return 429 immediately.
4. `t.verifyToken(token, verifyOpts{skipReplayMarking: true})` — reuses signature, issuer, audience, expiration, JTI Get (replay detection). The `skipReplayMarking` flag gates ONLY the JTI Set at `token_manager.go:108-143`; the JTI Get at `token_manager.go:44-47, 80-89` remains active so revoked tokens (via `RevokeToken` adding to blacklist) are still rejected.
5. **Re-parse claims** (`parseJWT(token)` is cheap and already done internally; reuse via a single decode if practical).
6. **Token-type guard**: call existing `detectTokenType(jwt, token)` (`token_manager.go:187-303`). Reject when it returns `true` (ID token). Belt-and-braces: also reject if `claims["nonce"]` is a non-empty string or `claims["token_use"] == "id"`.
7. **Multi-audience hardening**: if `claims["aud"]` is a `[]interface{}` with length > 1, require `claims["azp"]` to be a non-empty string equal to `t.clientID`; reject otherwise.
8. **`iat` upper-age bound**: reject when `time.Now().Unix() - int64(claims["iat"].(float64)) > MaxTokenAgeSeconds` (default 86400).
9. **Optional introspection**: if `requireTokenIntrospection` is set, call `introspectToken`; reject if `active == false` (401); surface 503 on transport failure. Bearer-path introspection cache TTL is capped at 60s (not 5min) to keep the "real-time revocation" promise close to true.
10. **Identifier resolution**: read `t.bearerIdentifierClaim` (defaults to `"sub"`); do NOT use `t.userIdentifierClaim` (cookie path's setting, default `email`). The bearer path does NOT fall back to other claims because `jwt.Verify` already enforces non-empty `sub` (`jwt.go:416-419`). Empty/missing identifier → 401.
11. **Identifier sanitisation**: trim, then reject if length > 256 OR contains any of: `unicode.IsControl`, bidi-override (U+202AU+202E, U+2066U+2069), `,`, `;`, `=`.
12. Return `&principal{ Source: sourceBearer, … }`.
On any failure path: increment the per-IP `failedBearerAttempts` counter; return the appropriate HTTP status (401 / 403 / 429 / 503) without revealing the failure reason in the response body. Reason is logged at debug only, with the identifier (if resolved) hashed via SHA-256 truncated to 8 hex chars.
### 7.4 `forwardAuthorized(rw, req, *principal)` (new in `middleware.go`, extracted)
The shared post-auth pipeline. Lifted verbatim from the existing `processAuthorizedRequest`:
1. Roles/groups extraction via existing `extractGroupsAndRolesFromClaims`.
2. `allowedRolesAndGroups` gate (existing logic).
3. Inject `X-Forwarded-User`, `X-User-Groups`, `X-User-Roles`.
4. Inject `X-Auth-Request-*` (gated by `minimalHeaders`).
5. Header templates.
6. Security headers.
7. Cookie strip when `stripAuthCookies`.
8. **New**: `Authorization` header strip when `stripAuthorizationHeader` AND `principal.Source == sourceBearer`.
9. `t.next.ServeHTTP(rw, req)`.
Does not call `Save`, does not check `IsDirty`. Session persistence stays with the cookie-path caller.
### 7.5 `handleBearerRequest(rw, req)` (new in `bearer_auth.go`)
```
1. Detect "Authorization: Bearer <token>" (case-insensitive prefix).
2. token = TrimSpace(authHeader[7:]); reject empty.
3. p, err := buildPrincipalFromBearerToken(token).
On err → 401 with WWW-Authenticate, log reason at debug.
4. forwardAuthorized(rw, req, p).
```
Target: ~40 lines.
### 7.6 Refactor of `processAuthorizedRequest` (modify `middleware.go`)
Splits along the principal boundary:
- Session-specific part (backchannel-logout invalidation, `IsDirty` / `Save`) stays in `processAuthorizedRequest`.
- Everything else moves to `forwardAuthorized`.
- `processAuthorizedRequest` ends with `forwardAuthorized(rw, req, buildPrincipalFromSession(session))`.
### 7.7 `verifyOpts` extension to `verifyToken` (modify `token_manager.go`)
Add a parameter struct:
```go
type verifyOpts struct {
skipReplayMarking bool // suppress JTI Set (token_manager.go:108-143); blacklist Get stays active
}
```
Both the type and field are unexported (internal-only knob). Signature change: `verifyToken(token string)` becomes `verifyToken(token string, opts verifyOpts)`. Existing callers pass `verifyOpts{}` (zero value = current behaviour). Bearer path passes `verifyOpts{skipReplayMarking: true}`.
**Critical semantics — must be reflected in implementation and tests:**
- `skipReplayMarking` only gates the **Set** at `token_manager.go:108-143` (the call adding the JTI to the blacklist and replay cache).
- The blacklist **Get** at `token_manager.go:44-47, 80-89` stays unconditionally active on the bearer path. Tokens revoked via `RevokeToken` (which adds the JTI to the blacklist) MUST still be rejected on the bearer path.
- Must NOT be implemented by mutating `t.disableReplayDetection` (struct field) — that would create a cross-request race that disables replay protection globally.
A targeted regression test exercises: bearer token verified once → admin calls `RevokeToken` adding the JTI to the blacklist → same token replayed → 401.
### 7.8 Config additions (modify `settings.go`)
```go
EnableBearerAuth bool `json:"enableBearerAuth,omitempty"`
BearerIdentifierClaim string `json:"bearerIdentifierClaim,omitempty"`
StripAuthorizationHeader bool `json:"stripAuthorizationHeader,omitempty"`
BearerEmitWWWAuthenticate bool `json:"bearerEmitWWWAuthenticate,omitempty"`
BearerOverridesCookie bool `json:"bearerOverridesCookie,omitempty"`
MaxTokenAgeSeconds int64 `json:"maxTokenAgeSeconds,omitempty"`
MaxIdentifierLength int `json:"maxIdentifierLength,omitempty"`
BearerFailureThreshold int `json:"bearerFailureThreshold,omitempty"`
BearerFailureWindowSeconds int `json:"bearerFailureWindowSeconds,omitempty"`
BearerFailurePenaltySeconds int `json:"bearerFailurePenaltySeconds,omitempty"`
```
Defaults (applied in `CreateConfig` for the bearer-related fields; values >0 only honoured when `EnableBearerAuth=true`):
- `EnableBearerAuth`: `false`.
- `BearerIdentifierClaim`: `"sub"`.
- `StripAuthorizationHeader`: `true`.
- `BearerEmitWWWAuthenticate`: `true` (RFC 6750 hint enabled by default; flip to false if recon-exposure is a concern).
- `BearerOverridesCookie`: `false` (cookie wins when both present; flip to `true` for the legacy/industry-default behaviour).
- `MaxTokenAgeSeconds`: `86400` (24h upper bound on `iat`).
- `MaxIdentifierLength`: `256`.
- `BearerFailureThreshold`: `20` (consecutive 401s per IP before throttle).
- `BearerFailureWindowSeconds`: `60`.
- `BearerFailurePenaltySeconds`: `60` (429 reply for this long after threshold tripped).
### 7.9 Startup validation (modify `main.go` `New()`)
- `EnableBearerAuth && Audience == ""` → fatal error.
- `EnableBearerAuth && !StrictAudienceValidation` → warning log (recommended hardening).
- `EnableBearerAuth && BearerIdentifierClaim == "email"` → fatal error (the bearer path is M2M and an `email` identifier without `email_verified` enforcement is a spoofing vector; default `BearerIdentifierClaim=sub` avoids this; explicit override to `email` is rejected).
- `EnableBearerAuth && MaxTokenAgeSeconds <= 0` → reset to default 86400 with info log.
- `EnableBearerAuth && BearerFailureThreshold <= 0` → reset to default 20 with info log.
## 8. Data Flow
### 8.1 Bearer path
```
ServeHTTP entry (pre-init paths unchanged: logout, backchannel, frontchannel, excluded URLs, SSE/WS bypass)
├─ enableBearerAuth == false? → fall through to cookie path
└─ enableBearerAuth == true AND Authorization starts with "Bearer "
handleBearerRequest
├─ format guards (empty, length, segment count)
verifyToken(token, verifyOpts{SkipReplayMarking: true})
│ signature, issuer, audience (strict), exp
classifyToken(claims) → reject ID tokens
if requireTokenIntrospection: introspectToken → active check
resolveIdentifier(claims) → sanitiseIdentifier
principal{Source: sourceBearer, …}
forwardAuthorized(rw, req, principal)
├─ roles/groups gate (403 on deny)
├─ header injection
├─ header templates
├─ security headers
├─ strip OIDC cookies (existing)
├─ strip Authorization header (new, when configured)
└─ next.ServeHTTP(rw, req)
```
### 8.2 Cookie path (refactored, semantically unchanged)
```
processAuthorizedRequest
1. Session validity / backchannel-logout invalidation (unchanged).
2. principal := buildPrincipalFromSession(session).
3. forwardAuthorized(rw, req, principal).
4. if session.IsDirty(): session.Save().
```
## 9. Error Handling
| Trigger | Status | Body | WWW-Authenticate | Debug log reason |
|---|---|---|---|---|
| Empty bearer after prefix | 401 | `Unauthorized` | `Bearer error="invalid_request"` | empty bearer token |
| Token over MaxLength | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token exceeds max length |
| Not a 3-segment JWT | 401 | `Unauthorized` | `Bearer error="invalid_token"` | malformed JWT |
| Disallowed `alg` (e.g. none, HS*) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | unsupported alg |
| Missing/oversized/bad-charset `kid` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid kid |
| Signature / issuer / aud / exp fail | 401 | `Unauthorized` | `Bearer error="invalid_token"` | reason from verifyToken (category only) |
| `iat` older than MaxTokenAgeSeconds | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token too old (iat outside age bound) |
| Multi-aud without matching `azp` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | multi-aud token without azp match |
| Detected as ID token | 401 | `Unauthorized` | `Bearer error="invalid_token"` | ID tokens not accepted on bearer path |
| JTI blacklisted (revoked) | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token JTI in blacklist |
| Introspection `active=false` | 401 | `Unauthorized` | `Bearer error="invalid_token"` | token inactive at IdP |
| Introspection endpoint failure | 503 | `Service Unavailable` | (none) | introspection unavailable |
| Identifier claim missing/empty | 401 | `Unauthorized` | `Bearer error="invalid_token"` | no identifier claim |
| Identifier fails sanitisation | 401 | `Unauthorized` | `Bearer error="invalid_token"` | invalid identifier characters |
| Per-IP failure threshold tripped | 429 | `Too Many Requests` | (none); `Retry-After: <BearerFailurePenaltySeconds>` | source IP in penalty box |
| Roles/groups not allowed | 403 | `Access denied` | (none) | user not in allowedRolesAndGroups |
Responses never include token contents, never include the raw failure reason, and never set `Location` headers (API clients cannot follow redirects).
## 10. Edge Cases
1. **Both bearer header and cookie session present.** Cookie wins by default (safer against browser/extension/proxy bearer injection). `BearerOverridesCookie=true` flips to bearer-wins. Either way: WARN log includes both source markers so operators can audit.
2. **`Authorization: Basic …`.** Not bearer; cookie path runs as today.
3. **`Authorization: Bearer ` (trailing space, no value).** Empty after trim → 401.
4. **Mixed-case prefix (`bearer`, `BEARER`, `BeArEr`).** Case-insensitive prefix check; token value preserved verbatim.
5. **Multiple `Authorization` headers.** Use only the first (Go `http.Header.Get` default). Documented.
6. **Bearer during OIDC init wait.** Bearer requests also block on init: we need `issuerURL`, `audience`, JWKs ready. If init fails, bearer requests return 503 just like cookie requests.
7. **SSE / WebSocket bypass with bearer.** Bypass paths keep cookie-only behaviour. Operators who want bearer on streaming endpoints must remove SSE/WS bypass. Documented.
8. **Logout endpoint with bearer.** Logout runs before bearer detection. Treated as cookie-session logout; bearer token revocation requires IdP-side action.
9. **Excluded URLs with bearer.** Bypass excluded URLs as today; bearer not validated on excluded paths. ADDITIONALLY: `Authorization: Bearer` is stripped from the request before forwarding so the token can't leak into the excluded endpoint's downstream logs / metrics scrapers / health checks.
10. **Concurrent identical bearer requests.** Existing `tokenCache` is concurrency-safe; no new locking.
11. **Client rotates token between requests.** Independent verification per token; independent cache entries.
12. **Clock skew.** Use existing `jwt.Verify` leeway. (If absent, add ±30s as a separate change; out of scope here.)
## 11. Testing Strategy
### 11.1 Integration tests (new `bearer_auth_test.go`)
Table-driven test against a real `httptest.Server` and the full `ServeHTTP` flow. Coverage matrix:
- Valid access token + allowed roles → 200, `next` ran, `X-Forwarded-User` set.
- Valid token without configured roles → 200.
- Wrong audience, expired, tampered signature → 401, `next` did not run.
- ID token presented → 401 (`ID tokens not accepted`).
- Malformed JWT (2 segments) → 401.
- Oversized token (> MaxLength) → 401.
- Empty bearer → 401.
- Missing identifier claim → 401.
- Identifier containing `\r\n` → 401.
- `allowedRolesAndGroups` mismatch → 403.
- `allowedRolesAndGroups` match → 200.
- `EnableBearerAuth=false` + bearer header → cookie path runs (302 to `/authorize`).
- Bearer + valid cookie session → bearer wins, 200.
- `StripAuthorizationHeader=true` → downstream sees no `Authorization`.
- `StripAuthorizationHeader=false` → downstream sees `Authorization`.
- Case variants (`bearer`, `BEARER`) → 200.
- SSE bypass + bearer → cookie-only check applies (bearer ignored).
- **Replay regression**: same token 1000 times in a row → all 200.
- **Cache-evict regression**: same token, force-evict `tokenCache` between iterations (call `tokenCache.Delete` directly), replay → still 200 (verifies `skipReplayMarking` doesn't poison the blacklist).
- **Revocation-while-bearer regression**: bearer token verified once → admin calls `RevokeToken` adding JTI to blacklist → same token presented → 401 (verifies blacklist Get stays active on bearer path even with `skipReplayMarking` set).
- **Alg-pin: token signed with `alg=none`** → 401, no JWKS fetch happens (verify with a counting mock).
- **`kid` injection: 50KB random kid** → 401 immediately, no JWKS fetch.
- **Per-IP throttle**: 21 bad bearer requests from same IP within 1 minute → 22nd returns 429 + Retry-After.
- **`iat` upper-age**: token with `iat = now - 25h` → 401 (older than 24h default).
- **Multi-aud without azp**: aud = `["a", "b"]`, no azp → 401.
- **Multi-aud with matching azp**: aud = `["api-aud", "other"]`, azp = clientID → 200.
- **Identifier with bidi-override**: sub contains U+202E → 401.
- **Identifier with comma**: sub = `"alice,bob"` → 401.
- **Identifier over 256 bytes** → 401.
- **`UserIdentifierClaim=email` at startup with EnableBearerAuth=true** → startup fails.
- **Excluded URL + bearer**: bearer header presented on excluded URL → request forwarded, downstream sees no `Authorization` header (stripped).
### 11.2 Unit tests (in `bearer_auth_test.go`)
- `classifyToken`: ID-token detection, access-token detection by `scope`/`scp`/`token_use`, ambiguous → reject.
- `resolveIdentifier`: precedence (`userIdentifierClaim``sub``client_id`/`azp`); missing → error; empty string → error.
- `sanitizeIdentifier`: rejects all `unicode.IsControl`; accepts email/sub-style values.
### 11.3 Introspection tests (`bearer_auth_introspection_test.go`)
- Token valid + introspection `active=true` → 200.
- Token valid + introspection `active=false` → 401.
- Introspection endpoint 500 → 503.
- Second request hits introspection cache (no second HTTP call).
### 11.4 Startup validation tests (extend `settings_test.go` / `main_test.go`)
- `EnableBearerAuth=true, Audience=""``New()` errors.
- `EnableBearerAuth=true, StrictAudienceValidation=false` → succeeds with warning.
- `EnableBearerAuth=false` → no validation; existing tests untouched.
### 11.5 Cookie-path regression suite
- All existing `TestServeHTTP_*` tests in `main_servehttp_test.go` pass unmodified.
- Add: cookie session, `EnableBearerAuth=true`, no bearer header → identical behaviour to baseline.
- Add: dirty session still triggers `Save()` after refactor.
### 11.6 Principal invariants
- `buildPrincipalFromSession`: `Source == sourceSession`; `IDToken` / `RefreshToken` populated when present in session.
- `buildPrincipalFromBearerToken`: `Source == sourceBearer`; `IDToken == ""`, `RefreshToken == ""`.
- `forwardAuthorized` produces identical headers for equivalent principals regardless of source.
### 11.7 Coverage gate
- New code in `bearer_auth.go` and `principal.go`: ≥ 90% line coverage.
- `forwardAuthorized` coverage ≥ existing `processAuthorizedRequest` coverage baseline.
### 11.8 Out of scope (follow-ups)
- Load test of bearer vs cookie hot path.
- Fuzzing the JWT parser.
- Additional auth methods (mTLS, API keys) — design enables them, but they are separate work.
## 12. Migration / Rollout
Default-off. Existing deployments observe no behavioural change. Operators opt in by setting:
```yaml
enableBearerAuth: true
audience: https://api.example.com # required when bearer enabled
# optional:
stripAuthorizationHeader: true # default
requireTokenIntrospection: false # default; set true for real-time revocation
userIdentifierClaim: client_id # optional override; defaults to sub fallback chain
```
Documentation: update `docs/CONFIGURATION.md` with a bearer-auth section, and add a new `docs/BEARER_AUTH.md` covering the security model, threat assumptions (token issuer is trusted; audience must be set; bearer means trust the issuer's revocation policy unless introspection enabled), and recommended configurations for common IdPs.
## 13. Security Considerations
| Concern | Mitigation |
|---|---|
| Token confusion (ID token used as bearer) | Reuse `detectTokenType` (`token_manager.go:187-303`) which checks `nonce`, `typ: at+jwt`, `token_use`, `scope`, aud-vs-clientID. Belt-and-braces: explicit `nonce` + `token_use == "id"` rejection on top. |
| Audience confusion (token for service B accepted by A) | `Audience` mandatory at startup; verified via existing `VerifyJWTSignatureAndClaims`; multi-aud tokens require matching `azp == clientID`. |
| Replay-via-blacklist false positive | `verifyOpts{skipReplayMarking: true}` on bearer path. Gates ONLY the Set; the Get stays so revoked tokens still fail. |
| Revocation lag | Optional RFC 7662 introspection. Bearer-path introspection cache TTL capped at 60s. Set `RequireTokenIntrospection=true` for real-time revocation. |
| `alg`-confusion / `alg=none` attacks | Hard-pin asymmetric allowlist at bearer entry, **before** JWKS fetch. Prevents wasted upstream calls and locks out HS/none probes. |
| `kid` injection / JWKS amplification | `kid` length cap (256 bytes) + charset allowlist enforced at bearer entry. |
| Bearer 401 brute-force / oracle | Per-IP `failedBearerAttempts` cache; configurable threshold + penalty box returning 429 + `Retry-After`. |
| `iat` clock-manipulation / forever-tokens | `MaxTokenAgeSeconds` upper bound (default 24h); cookie path unchanged. |
| Identifier-driven header injection | `sanitizeIdentifier`: length cap, control-char + bidi-override + `,;=` rejection. `net/http` rejects CRLF on the wire too (defence in depth). |
| Token leakage downstream | `StripAuthorizationHeader=true` by default. Also: `Authorization` stripped on excluded-URL requests so bearer can't leak into health/metrics downstream logs. |
| Token-in-logs | All log paths log reason categories, not raw tokens. Identifier hashed via SHA-256 truncated to 8 hex chars before any info/warn-level emission (full identifier only at debug). New `safeLogAuthEvent(category, hashedIdentifier, reasonCode)` helper makes this hard to misuse. |
| `email` claim spoofing | Startup fails if `EnableBearerAuth && UserIdentifierClaim == "email"`. Future human-user bearer iteration must add `email_verified` enforcement. |
| Bypass on SSE / WS endpoints | SSE/WS bypass keeps cookie-only behaviour; bearer ignored. Operators choose to widen if needed. |
| Mixed bearer + cookie precedence | Cookie wins by default (safer for browser scenarios); `BearerOverridesCookie=true` flips. WARN log on both-present requests. |
| Configuration drift (operator forgets audience) | Startup fails when `EnableBearerAuth=true && Audience==""`. |
| Downstream blast radius when `StripAuthorizationHeader=false` | Documented: forwarded bearer extends token's blast radius to all downstream services. Logs at those services become token stores. Operators must treat downstream log policy accordingly. |
| Introspection auth method (pre-existing gap, called out) | `token_introspection.go:80` uses `client_secret_basic` only; does not honour `private_key_jwt`. Out of scope for this PR but documented as a follow-up; operators using `ClientAuthMethod=private_key_jwt` + `RequireTokenIntrospection=true` should be aware introspection will use basic auth. |
## 14. Open Questions
None — all design decisions resolved during brainstorming + security review. Implementation may surface incidental questions (e.g. exact clock-skew leeway in `jwt.Verify`); those are out of scope for this spec and handled in the implementation plan.
## 14a. Security Review Reference
This design was reviewed by the `security-reviewer` subagent on 2026-05-18. Findings incorporated:
- **Critical**: C1 (classifier reuses `detectTokenType`), C2 (sub fallback dropped — unreachable due to `jwt.go:416`), C3 (replay-marking gates only Set, not Get; revocation regression test added).
- **High**: H1 (alg pinned at bearer entry), H2 (kid length + charset), H3 (cookie wins by default, configurable), H4 (per-IP 401 throttle), H5 (multi-aud requires azp).
- **Medium**: M1 (identifier max-length + bidi reject + delimiter chars), M2 (introspection cache TTL capped at 60s on bearer path), M4 (log-hashing via SHA-256[:8]), M5 (StripAuth blast-radius documented), M6 (iat upper-age bound), M7 (Authorization stripped on excluded URLs).
- **Low/Nit**: L2 (renamed to `BearerEmitWWWAuthenticate`), N3 (startup rejects `UserIdentifierClaim=email`).
- **Documented as pre-existing gaps (follow-up PRs)**: M3 (introspection auth method doesn't honour `private_key_jwt`).
## 15. Implementation Plan Reference
To be produced by the `writing-plans` skill in a follow-up document at `docs/superpowers/plans/2026-05-18-bearer-token-auth-plan.md`. The plan decomposes this design into ordered, independently-testable PRs.
+61 -191
View File
@@ -50,6 +50,7 @@ type DynamicClientRegistrar struct {
logger *Logger
config *DynamicClientRegistrationConfig
registrationResponse *ClientRegistrationResponse
store DCRCredentialsStore // Storage backend for credentials
providerURL string
mu sync.RWMutex
}
@@ -73,8 +74,37 @@ func NewDynamicClientRegistrar(
}
}
// NewDynamicClientRegistrarWithStore creates a new dynamic client registrar with a specific storage backend
func NewDynamicClientRegistrarWithStore(
httpClient *http.Client,
logger *Logger,
dcrConfig *DynamicClientRegistrationConfig,
providerURL string,
store DCRCredentialsStore,
) *DynamicClientRegistrar {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
return &DynamicClientRegistrar{
httpClient: httpClient,
logger: logger,
config: dcrConfig,
providerURL: providerURL,
store: store,
}
}
// SetStore sets the credentials store for the registrar
// This allows setting the store after creation when the cache manager is available
func (r *DynamicClientRegistrar) SetStore(store DCRCredentialsStore) {
r.mu.Lock()
defer r.mu.Unlock()
r.store = store
}
// RegisterClient performs dynamic client registration with the OIDC provider
// It first attempts to load existing credentials from a file if persistence is enabled,
// It first attempts to load existing credentials from storage if persistence is enabled,
// then registers a new client if no valid credentials exist.
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
if r.config == nil || !r.config.Enabled {
@@ -83,10 +113,13 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Try to load existing credentials if persistence is enabled
if r.config.PersistCredentials {
if resp, err := r.loadCredentials(); err == nil && resp != nil {
resp, err := r.loadCredentialsFromStore(ctx)
if err != nil {
r.logger.Debugf("Failed to load credentials from store: %v", err)
} else if resp != nil {
// Check if credentials are still valid (not expired)
if r.areCredentialsValid(resp) {
r.logger.Info("Loaded existing client credentials from file")
r.logger.Info("Loaded existing client credentials from storage")
r.mu.Lock()
r.registrationResponse = resp
r.mu.Unlock()
@@ -179,7 +212,7 @@ func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registratio
// Persist credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
if err := r.saveCredentialsToStore(ctx, &regResp); err != nil {
r.logger.Errorf("Failed to persist client credentials: %v", err)
// Don't fail registration if persistence fails
}
@@ -315,7 +348,29 @@ func (r *DynamicClientRegistrar) credentialsFilePath() string {
return "/tmp/oidc-client-credentials.json"
}
// saveCredentials persists client credentials to a file
// loadCredentialsFromStore loads client credentials from the configured storage backend
// Falls back to legacy file-based loading if no store is configured
func (r *DynamicClientRegistrar) loadCredentialsFromStore(ctx context.Context) (*ClientRegistrationResponse, error) {
// Use store if available
if r.store != nil {
return r.store.Load(ctx, r.providerURL)
}
// Fallback to legacy file-based loading
return r.loadCredentials()
}
// saveCredentialsToStore persists client credentials to the configured storage backend
// Falls back to legacy file-based saving if no store is configured
func (r *DynamicClientRegistrar) saveCredentialsToStore(ctx context.Context, resp *ClientRegistrationResponse) error {
// Use store if available
if r.store != nil {
return r.store.Save(ctx, r.providerURL, resp)
}
// Fallback to legacy file-based saving
return r.saveCredentials(resp)
}
// saveCredentials persists client credentials to a file (legacy method)
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
filePath := r.credentialsFilePath()
@@ -333,7 +388,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
return nil
}
// loadCredentials loads client credentials from a file
// loadCredentials loads client credentials from a file (legacy method)
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
filePath := r.credentialsFilePath()
@@ -353,188 +408,3 @@ func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse,
return &resp, nil
}
// UpdateClientRegistration updates an existing client registration using RFC 7592
// This requires the registration_client_uri and registration_access_token from the original registration
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to update")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Build update request
reqBody, err := r.buildRegistrationRequest()
if err != nil {
return nil, fmt.Errorf("failed to build update request: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create update request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("update request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read update response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse update response: %w", err)
}
// Update cache
r.mu.Lock()
r.registrationResponse = &regResp
r.mu.Unlock()
// Persist updated credentials if enabled
if r.config.PersistCredentials {
if err := r.saveCredentials(&regResp); err != nil {
r.logger.Errorf("Failed to persist updated credentials: %v", err)
}
}
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
return &regResp, nil
}
// ReadClientRegistration reads the current client registration using RFC 7592
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return nil, fmt.Errorf("no existing registration to read")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create read request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("read request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
// Handle error responses
if resp.StatusCode != http.StatusOK {
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse successful response
var regResp ClientRegistrationResponse
if err := json.Unmarshal(body, &regResp); err != nil {
return nil, fmt.Errorf("failed to parse read response: %w", err)
}
return &regResp, nil
}
// DeleteClientRegistration deletes the client registration using RFC 7592
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
r.mu.RLock()
cachedResp := r.registrationResponse
r.mu.RUnlock()
if cachedResp == nil {
return fmt.Errorf("no existing registration to delete")
}
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
if err != nil {
return fmt.Errorf("failed to create delete request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
// Execute request
resp, err := r.httpClient.Do(req)
if err != nil {
return fmt.Errorf("delete request failed: %w", err)
}
defer resp.Body.Close()
// Handle error responses (204 No Content is success)
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
var regError ClientRegistrationError
if jsonErr := json.Unmarshal(body, &regError); jsonErr == nil && regError.Error != "" {
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
}
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
}
// Clear cache
r.mu.Lock()
r.registrationResponse = nil
r.mu.Unlock()
// Remove credentials file if persistence is enabled
if r.config.PersistCredentials {
filePath := r.credentialsFilePath()
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
r.logger.Errorf("Failed to remove credentials file: %v", err)
}
}
r.logger.Info("Successfully deleted client registration")
return nil
}
-252
View File
@@ -735,258 +735,6 @@ func TestDCRConfigDefaults(t *testing.T) {
}
}
// TestUpdateClientRegistration tests the RFC 7592 client update functionality
func TestUpdateClientRegistration(t *testing.T) {
updateCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPut {
updateCalled = true
// Verify authorization header
if r.Header.Get("Authorization") == "" {
t.Error("Missing Authorization header for update")
}
resp := ClientRegistrationResponse{
ClientID: "updated-client-id",
ClientSecret: "updated-client-secret",
RegistrationAccessToken: "new-access-token",
RegistrationClientURI: r.URL.String(),
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
ClientMetadata: &ClientRegistrationMetadata{
RedirectURIs: []string{"https://example.com/callback"},
},
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform update
ctx := context.Background()
resp, err := registrar.UpdateClientRegistration(ctx)
if err != nil {
t.Fatalf("Update failed: %v", err)
}
if !updateCalled {
t.Error("Update endpoint was not called")
}
if resp.ClientID != "updated-client-id" {
t.Errorf("Updated ClientID mismatch: got %s", resp.ClientID)
}
}
// TestDeleteClientRegistration tests the RFC 7592 client deletion functionality
func TestDeleteClientRegistration(t *testing.T) {
deleteCalled := false
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodDelete {
deleteCalled = true
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()
tempDir := t.TempDir()
credentialsFile := filepath.Join(tempDir, "credentials.json")
// Create a credentials file to test deletion
os.WriteFile(credentialsFile, []byte(`{"client_id":"test"}`), 0600)
dcrConfig := &DynamicClientRegistrationConfig{
Enabled: true,
PersistCredentials: true,
CredentialsFile: credentialsFile,
}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Perform delete
ctx := context.Background()
err := registrar.DeleteClientRegistration(ctx)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
if !deleteCalled {
t.Error("Delete endpoint was not called")
}
// Verify cache is cleared
if registrar.GetCachedResponse() != nil {
t.Error("Cached response should be cleared after deletion")
}
// Verify credentials file is deleted
if _, err := os.Stat(credentialsFile); !os.IsNotExist(err) {
t.Error("Credentials file should be deleted")
}
}
// TestReadClientRegistration tests the RFC 7592 client read functionality
func TestReadClientRegistration(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
resp := ClientRegistrationResponse{
ClientID: "read-client-id",
ClientSecret: "read-client-secret",
RedirectURIs: []string{"https://example.com/callback"},
ResponseTypes: []string{"code"},
GrantTypes: []string{"authorization_code"},
ApplicationType: "web",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
}))
defer server.Close()
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
server.Client(),
NewLogger("DEBUG"),
dcrConfig,
server.URL,
)
// Set up cached response with management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "original-client-id",
RegistrationAccessToken: "access-token",
RegistrationClientURI: server.URL + "/register/client123",
}
registrar.mu.Unlock()
// Read registration
ctx := context.Background()
resp, err := registrar.ReadClientRegistration(ctx)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
if resp.ClientID != "read-client-id" {
t.Errorf("Read ClientID mismatch: got %s", resp.ClientID)
}
}
// TestOperationsWithoutCachedResponse tests error handling when no cached response exists
func TestOperationsWithoutCachedResponse(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
ctx := context.Background()
// Test Update without cached response
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Update should fail without cached response: %v", err)
}
// Test Read without cached response
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Read should fail without cached response: %v", err)
}
// Test Delete without cached response
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "no existing registration") {
t.Errorf("Delete should fail without cached response: %v", err)
}
}
// TestOperationsWithoutManagementCredentials tests error handling without management URIs
func TestOperationsWithoutManagementCredentials(t *testing.T) {
dcrConfig := &DynamicClientRegistrationConfig{Enabled: true}
registrar := NewDynamicClientRegistrar(
&http.Client{},
NewLogger("DEBUG"),
dcrConfig,
"https://example.com",
)
// Set up cached response WITHOUT management credentials
registrar.mu.Lock()
registrar.registrationResponse = &ClientRegistrationResponse{
ClientID: "test-client-id",
// Missing RegistrationAccessToken and RegistrationClientURI
}
registrar.mu.Unlock()
ctx := context.Background()
// Test Update without management credentials
_, err := registrar.UpdateClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Update should fail without management credentials: %v", err)
}
// Test Read without management credentials
_, err = registrar.ReadClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Read should fail without management credentials: %v", err)
}
// Test Delete without management credentials
err = registrar.DeleteClientRegistration(ctx)
if err == nil || !stringContains(err.Error(), "registration management not supported") {
t.Errorf("Delete should fail without management credentials: %v", err)
}
}
// stringContains is a helper function to check if a string contains a substring
func stringContains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContainsHelper(s, substr))
+27
View File
@@ -2,6 +2,8 @@ package traefikoidc
import (
"context"
"crypto"
"fmt"
"net/http"
"sync"
"sync/atomic"
@@ -40,6 +42,31 @@ func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, http
return m.JWKS, m.Err
}
func (m *EnhancedMockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
jwks, err := m.GetJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
if jwks == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range jwks.Keys {
k := &jwks.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *EnhancedMockJWKCache) Cleanup() {
atomic.AddInt32(&m.CleanupCalls, 1)
m.mu.Lock()
+7 -6
View File
@@ -539,10 +539,10 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
return true
}
errStr := err.Error()
errStr := strings.ToLower(err.Error())
for _, retryableErr := range re.config.RetryableErrors {
if contains(errStr, retryableErr) {
if contains(errStr, strings.ToLower(retryableErr)) {
return true
}
}
@@ -551,7 +551,7 @@ func (re *RetryExecutor) isRetryableError(err error) bool {
if netErr.Timeout() {
return true
}
errStr := netErr.Error()
errStr := strings.ToLower(netErr.Error())
temporaryPatterns := []string{
"connection refused",
"connection reset",
@@ -859,8 +859,9 @@ func (gd *GracefulDegradation) ExecuteWithFallback(serviceName string, primary f
// isServiceDegraded checks if a service is currently degraded
func (gd *GracefulDegradation) isServiceDegraded(serviceName string) bool {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
// Uses a write lock because the recovery-timeout branch deletes from the map.
gd.mutex.Lock()
defer gd.mutex.Unlock()
degradedTime, exists := gd.degradedServices[serviceName]
if !exists {
@@ -954,7 +955,7 @@ func (gd *GracefulDegradation) GetDegradedServices() []string {
gd.mutex.RLock()
defer gd.mutex.RUnlock()
var degraded []string
degraded := make([]string, 0, len(gd.degradedServices))
for serviceName := range gd.degradedServices {
degraded = append(degraded, serviceName)
}
+10
View File
@@ -101,6 +101,16 @@ http:
providerURL: "https://auth.example.com"
callbackURL: "/oauth2/callback"
# ----------------------------------------------------------------
# Optional: switch to RFC 7523 private_key_jwt client auth
# (Entra ID, Okta, Auth0, Keycloak). Replaces clientSecret with a
# signed JWT assertion. See README for details and PEM formats.
# ----------------------------------------------------------------
# clientAuthMethod: "private_key_jwt"
# clientAssertionKeyPath: "/etc/traefik/oidc/client-key.pem"
# clientAssertionKeyID: "prod-key-2026"
# clientAssertionAlg: "RS256" # or PS256/384/512, ES256/384/512
# Session Configuration
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
sessionMaxAge: 28800 # 8 hours
+1 -1
View File
@@ -4,8 +4,8 @@ go 1.24.0
require (
github.com/alicebob/miniredis/v2 v2.35.0
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/lukaszraczylo/oss-telemetry v0.2.3
github.com/redis/go-redis/v9 v9.17.2
github.com/stretchr/testify v1.10.0
golang.org/x/time v0.14.0
+2 -2
View File
@@ -12,12 +12,12 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
+63 -5
View File
@@ -17,6 +17,21 @@ import (
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// newUUIDv4 returns an RFC 4122 v4 UUID string (e.g.
// "f47ac10b-58cc-4372-a567-0e02b2c3d479") backed by crypto/rand. Used for CSRF
// tokens and other opaque random identifiers — replaces github.com/google/uuid
// to keep the plugin stdlib-only on the production path.
func newUUIDv4() (string, error) {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return "", fmt.Errorf("could not generate UUID: %w", err)
}
b[6] = (b[6] & 0x0f) | 0x40 // version 4
b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
}
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
// The nonce is used to prevent replay attacks and associate client sessions with ID tokens.
// Returns:
@@ -92,9 +107,12 @@ type TokenResponse struct {
// - An error if the token exchange fails (e.g., network error, provider error, invalid grant)
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
"grant_type": {grantType},
}
// client_id is sent in the body for every method except client_secret_basic,
// where it is carried in the Authorization header per RFC 6749 §2.3.1.
if t.clientAuthMethod != "client_secret_basic" || t.clientAssertion != nil {
data.Set("client_id", t.clientID)
}
if grantType == "authorization_code" {
@@ -126,16 +144,33 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
}
}
// Read tokenURL with RLock
// Read tokenURL with RLock — needed as audience for private_key_jwt (RFC 7523 §3).
t.metadataMu.RLock()
tokenURL := t.tokenURL
t.metadataMu.RUnlock()
useBasicAuth := false
if t.clientAssertion != nil {
assertion, err := t.clientAssertion.Sign(tokenURL, t.clientID)
if err != nil {
return nil, fmt.Errorf("failed to sign client assertion: %w", err)
}
data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
data.Set("client_assertion", assertion)
} else if t.clientAuthMethod == "client_secret_basic" {
useBasicAuth = true
} else {
data.Set("client_secret", t.clientSecret)
}
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if useBasicAuth {
setOAuthBasicAuth(req, t.clientID, t.clientSecret)
}
resp, err := client.Do(req)
if err != nil {
@@ -336,6 +371,7 @@ func createStringMap(keys []string) map[string]struct{} {
// and redirects to the provider's logout endpoint or configured post-logout URI.
// It handles potential errors during session retrieval or clearing.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing logout request")
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
@@ -356,10 +392,19 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
// localRedirect is used when there is no provider end-session endpoint and
// the plugin redirects the browser itself. It must never be an absolute URL
// derived from the request host (X-Forwarded-Host is client-controllable and
// would be an open redirect); use a host-relative path, or the operator's
// own configured absolute URL, instead.
localRedirect := "/"
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
localRedirect = normalizeLogoutPath(postLogoutRedirectURI)
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
} else {
localRedirect = postLogoutRedirectURI
}
// Read endSessionURL with RLock
@@ -378,7 +423,7 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
http.Redirect(rw, req, localRedirect, http.StatusFound)
}
// BuildLogoutURL constructs a logout URL for the OIDC provider's end session endpoint.
@@ -407,6 +452,19 @@ func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (strin
return u.String(), nil
}
// setOAuthBasicAuth sets the Authorization header per RFC 6749 §2.3.1: the
// client_id and client_secret are form-urlencoded individually, joined with a
// colon, then base64-encoded. This differs from http.Request.SetBasicAuth,
// which skips the form-urlencode step — that matters for credentials with
// reserved characters (`:`, `@`, `+`, `%`, etc.) where the wire format would
// otherwise diverge from what the spec mandates.
func setOAuthBasicAuth(req *http.Request, clientID, clientSecret string) {
user := url.QueryEscape(clientID)
pass := url.QueryEscape(clientSecret)
auth := base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
req.Header.Set("Authorization", "Basic "+auth)
}
// deduplicateScopes removes duplicate scopes from a slice while preserving order.
// This ensures that OAuth scope parameters don't contain duplicates which could
// cause issues with some authorization servers.
+29
View File
@@ -0,0 +1,29 @@
package traefikoidc
import (
"regexp"
"testing"
)
// TestNewUUIDv4 verifies the in-house UUID v4 generator produces RFC 4122
// compliant identifiers. Locks in the replacement for github.com/google/uuid
// — a regression here would weaken the CSRF token used in the OIDC flow.
func TestNewUUIDv4(t *testing.T) {
rfc4122v4 := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)
const samples = 1000
seen := make(map[string]struct{}, samples)
for i := 0; i < samples; i++ {
got, err := newUUIDv4()
if err != nil {
t.Fatalf("newUUIDv4 failed: %v", err)
}
if !rfc4122v4.MatchString(got) {
t.Fatalf("UUID %q does not match RFC 4122 v4 format", got)
}
if _, dup := seen[got]; dup {
t.Fatalf("duplicate UUID emitted within %d samples: %q", samples, got)
}
seen[got] = struct{}{}
}
}
+13 -5
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
@@ -25,10 +26,16 @@ type HTTPClientConfig struct {
Timeout time.Duration
MaxConnsPerHost int
WriteBufferSize int
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
// RootCAs is an optional certificate pool used for TLS verification.
// A nil pool means "use the system trust store" (default behavior).
RootCAs *x509.CertPool
// InsecureSkipVerify disables TLS certificate verification.
// ONLY set this for local development against self-signed certificates.
InsecureSkipVerify bool
UseCookieJar bool
ForceHTTP2 bool
DisableKeepAlives bool
DisableCompression bool
}
// DefaultHTTPClientConfig returns the default configuration for general use
@@ -203,7 +210,8 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false, // Always verify certificates
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
+48 -9
View File
@@ -3,6 +3,7 @@ package traefikoidc
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
@@ -25,6 +26,10 @@ type sharedTransport struct {
lastUsed time.Time
transport *http.Transport
refCount int
// tlsKey identifies the TLS trust settings (CA pool + InsecureSkipVerify)
// this transport was built with, so the at-limit fallback only reuses a
// transport whose TLS configuration matches the caller's.
tlsKey string
}
var (
@@ -52,19 +57,26 @@ func GetGlobalTransportPool() *SharedTransportPool {
// GetOrCreateTransport gets or creates a shared transport with the given config
func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *http.Transport {
// SECURITY FIX: Check client limit before creating new transport
// SECURITY FIX: Check client limit before creating new transport.
if atomic.LoadInt32(&p.clientCount) >= p.maxClients {
// Return existing transport if limit reached
p.mu.RLock()
defer p.mu.RUnlock()
// At the client limit: only reuse a transport that was built for the
// SAME config (same TLS trust store). refCount is mutated under the
// write lock to avoid a data race, and a transport created for a
// different configuration is never handed back — doing so could apply
// the wrong (possibly verification-disabled) TLS settings to a request.
want := tlsConfigKey(config)
p.mu.Lock()
defer p.mu.Unlock()
for _, shared := range p.transports {
if shared != nil && shared.transport != nil {
if shared != nil && shared.transport != nil && shared.tlsKey == want {
shared.refCount++
shared.lastUsed = time.Now()
return shared.transport
}
}
// If no transport available, return nil (caller should handle)
// No TLS-compatible transport available; return nil so the caller falls
// back to a default, certificate-verifying transport rather than one
// with a different (possibly verification-disabled) trust store.
return nil
}
@@ -103,7 +115,8 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
},
PreferServerCipherSuites: true,
InsecureSkipVerify: false,
RootCAs: config.RootCAs,
InsecureSkipVerify: config.InsecureSkipVerify, //nolint:gosec // opt-in, loud warning emitted at plugin startup
},
ForceAttemptHTTP2: config.ForceHTTP2,
TLSHandshakeTimeout: config.TLSHandshakeTimeout,
@@ -123,6 +136,7 @@ func (p *SharedTransportPool) GetOrCreateTransport(config HTTPClientConfig) *htt
transport: transport,
refCount: 1,
lastUsed: time.Now(),
tlsKey: tlsConfigKey(config),
}
return transport
@@ -205,8 +219,33 @@ func (p *SharedTransportPool) performCleanup() {
// configKey generates a unique key for a config
func (p *SharedTransportPool) configKey(config HTTPClientConfig) string {
// Simple key based on main parameters
return string(rune(config.MaxConnsPerHost)) + string(rune(config.MaxIdleConnsPerHost))
// Pool transports by the parameters that change TLS or connection
// behavior. RootCAs and InsecureSkipVerify MUST be part of the key:
// otherwise a middleware configured with a custom CA would share a
// transport with one using the system store, silently bypassing its
// CA configuration.
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%d|%d|%p|%s",
config.MaxConnsPerHost,
config.MaxIdleConnsPerHost,
config.RootCAs,
skip,
)
}
// tlsConfigKey identifies only the TLS trust settings of a config — the CA pool
// and the InsecureSkipVerify flag. Two configs with the same tlsConfigKey are
// safe to serve from the same transport even if other (non-TLS) parameters such
// as connection limits differ; configs with different TLS settings are not.
func tlsConfigKey(config HTTPClientConfig) string {
skip := "0"
if config.InsecureSkipVerify {
skip = "1"
}
return fmt.Sprintf("%p|%s", config.RootCAs, skip)
}
// Cleanup closes all transports and stops the cleanup goroutine
+14 -27
View File
@@ -10,6 +10,14 @@ import (
"unicode/utf8"
)
// Pre-compiled regex patterns for validation (const patterns should use MustCompile)
var (
emailRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
urlRegexPattern = regexp.MustCompile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
tokenRegexPattern = regexp.MustCompile(`^[A-Za-z0-9._-]+$`)
usernameRegexPattern = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`)
)
// InputValidator provides comprehensive input validation and sanitization
// to protect against common security vulnerabilities including SQL injection,
// XSS, path traversal, and other injection attacks. It validates and sanitizes
@@ -73,7 +81,7 @@ func DefaultInputValidationConfig() InputValidationConfig {
}
// NewInputValidator creates a new input validator with the specified configuration.
// It compiles all necessary regex patterns and initializes security pattern lists.
// It uses pre-compiled regex patterns and initializes security pattern lists.
//
// Parameters:
// - config: Validation configuration with size limits and mode settings.
@@ -81,29 +89,8 @@ func DefaultInputValidationConfig() InputValidationConfig {
//
// Returns:
// - A configured InputValidator instance.
// - An error if regex compilation fails.
// - An error (always nil, kept for API compatibility).
func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputValidator, error) {
// Compile regex patterns
emailRegex, err := regexp.Compile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
if err != nil {
return nil, fmt.Errorf("failed to compile email regex: %w", err)
}
urlRegex, err := regexp.Compile(`^https?://[a-zA-Z0-9.-]+(?:\.[a-zA-Z]{2,})?(?::[0-9]+)?(?:/[^\s]*)?$`)
if err != nil {
return nil, fmt.Errorf("failed to compile URL regex: %w", err)
}
tokenRegex, err := regexp.Compile(`^[A-Za-z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile token regex: %w", err)
}
usernameRegex, err := regexp.Compile(`^[a-zA-Z0-9._-]+$`)
if err != nil {
return nil, fmt.Errorf("failed to compile username regex: %w", err)
}
return &InputValidator{
maxTokenLength: config.MaxTokenLength,
maxURLLength: config.MaxURLLength,
@@ -112,10 +99,10 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
maxEmailLength: config.MaxEmailLength,
maxUsernameLength: config.MaxUsernameLength,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
emailRegex: emailRegex,
urlRegex: urlRegex,
tokenRegex: tokenRegex,
usernameRegex: usernameRegex,
emailRegex: emailRegexPattern,
urlRegex: urlRegexPattern,
tokenRegex: tokenRegexPattern,
usernameRegex: usernameRegexPattern,
sqlInjectionPatterns: []string{
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
"union", "select", "insert", "update", "delete", "drop",
+3
View File
@@ -24,6 +24,7 @@ type Config struct {
Type BackendType
RedisAddr string
RedisPassword string
TLSServerName string
PoolSize int
RedisDB int
CleanupInterval time.Duration
@@ -34,6 +35,8 @@ type Config struct {
EnableCircuitBreaker bool
EnableHealthCheck bool
EnableMetrics bool
EnableTLS bool
TLSSkipVerify bool
}
// DefaultConfig returns a default configuration for in-memory caching
+82 -35
View File
@@ -20,6 +20,7 @@ type HybridBackend struct {
ctx context.Context
syncWriteCacheTypes map[string]bool
asyncWriteBuffer chan *asyncWriteItem
l1BackfillBuffer chan *l1BackfillItem
cancel context.CancelFunc
wg sync.WaitGroup
l1Hits atomic.Int64
@@ -28,6 +29,7 @@ type HybridBackend struct {
l1Writes atomic.Int64
misses atomic.Int64
l2Hits atomic.Int64
l1BackfillDrops atomic.Int64
fallbackMode atomic.Bool
}
@@ -39,6 +41,15 @@ type asyncWriteItem struct {
ttl time.Duration
}
// l1BackfillItem represents a deferred write of an L2-resolved value back into
// L1. Backfills run on a single bounded worker so a burst of L2 hits cannot
// detonate the goroutine count (issue: ~1000% CPU under sustained polling).
type l1BackfillItem struct {
key string
value []byte
ttl time.Duration
}
// Logger interface for structured logging
type Logger interface {
Debugf(format string, args ...interface{})
@@ -114,6 +125,7 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
secondary: config.Secondary,
syncWriteCacheTypes: config.SyncWriteCacheTypes,
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
l1BackfillBuffer: make(chan *l1BackfillItem, config.AsyncBufferSize),
ctx: ctx,
cancel: cancel,
logger: config.Logger,
@@ -123,6 +135,11 @@ func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
h.wg.Add(1)
go h.asyncWriteWorker()
// Start L1 backfill worker (single goroutine) to bound goroutine growth on
// L2 hits regardless of request rate.
h.wg.Add(1)
go h.l1BackfillWorker()
// Start health monitoring
h.wg.Add(1)
go h.healthMonitor()
@@ -147,7 +164,7 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
// Check if we're in fallback mode
if h.fallbackMode.Load() {
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", redactKey(key))
return nil // Don't fail the operation if L2 is down
}
@@ -159,13 +176,13 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
// Synchronous write for critical cache types
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
h.errors.Add(1)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", redactKey(key), err)
h.recordL2Error()
// Don't fail the operation - L1 write succeeded
return nil
}
h.l2Writes.Add(1)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", redactKey(key))
} else {
// Asynchronous write for non-critical cache types
select {
@@ -175,10 +192,10 @@ func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl t
ttl: ttl,
ctx: ctx,
}:
h.logger.Debugf("Queued async write to L2 for key: %s", key)
h.logger.Debugf("Queued async write to L2 for key: %s", redactKey(key))
default:
// Buffer is full, log and continue
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", redactKey(key))
h.errors.Add(1)
}
}
@@ -192,7 +209,7 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
value, ttl, exists, err := h.primary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L1 get error for key %s: %v", key, err)
h.logger.Debugf("L1 get error for key %s: %v", redactKey(key), err)
}
if exists {
@@ -210,7 +227,7 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
value, ttl, exists, err = h.secondary.Get(ctx, key)
if err != nil {
h.errors.Add(1)
h.logger.Debugf("L2 get error for key %s: %v", key, err)
h.logger.Debugf("L2 get error for key %s: %v", redactKey(key), err)
h.recordL2Error()
h.misses.Add(1)
return nil, 0, false, nil // Don't propagate L2 errors
@@ -223,18 +240,10 @@ func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Durat
h.l2Hits.Add(1)
// Populate L1 cache with value from L2 (write-through on read)
// Use goroutine to avoid blocking the read path
go func() {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
}
}()
// Populate L1 cache with value from L2 (write-through on read).
// Hand off to the bounded backfill worker instead of spawning a goroutine
// per read - under burst that would mint thousands of goroutines.
h.queueL1Backfill(key, value, ttl)
return value, ttl, true, nil
}
@@ -371,6 +380,7 @@ func (h *HybridBackend) Close() error {
// Close async write channel
close(h.asyncWriteBuffer)
close(h.l1BackfillBuffer)
// Wait for workers to finish with timeout
done := make(chan struct{})
@@ -440,13 +450,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
for key, value := range l2Results {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
}(key, value)
h.queueL1Backfill(key, value, 0) // 0 = primary backend default TTL
}
}
} else {
@@ -455,13 +459,7 @@ func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string]
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
results[key] = value
h.l2Hits.Add(1)
// Asynchronously populate L1
go func(k string, v []byte, t time.Duration) {
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = h.primary.Set(writeCtx, k, v, t)
}(key, value, ttl)
h.queueL1Backfill(key, value, ttl)
}
}
}
@@ -538,6 +536,55 @@ func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, tt
return nil
}
// queueL1Backfill enqueues an L2-resolved value for write-through into L1.
// Drops on full buffer to keep the read path constant-time; the next L2 hit
// for the same key simply re-queues it.
func (h *HybridBackend) queueL1Backfill(key string, value []byte, ttl time.Duration) {
select {
case h.l1BackfillBuffer <- &l1BackfillItem{key: key, value: value, ttl: ttl}:
default:
h.l1BackfillDrops.Add(1)
h.logger.Debugf("L1 backfill buffer full, dropping for key: %s", redactKey(key))
}
}
// l1BackfillWorker drains the backfill queue serially. Single worker is
// intentional - L1 writes are local and cheap, and serializing them keeps
// goroutine count bounded under any read rate.
func (h *HybridBackend) l1BackfillWorker() {
defer h.wg.Done()
for {
select {
case <-h.ctx.Done():
// Drain remaining items best-effort then exit.
for len(h.l1BackfillBuffer) > 0 {
select {
case item := <-h.l1BackfillBuffer:
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
_ = h.primary.Set(writeCtx, item.key, item.value, item.ttl)
cancel()
default:
return
}
}
return
case item, ok := <-h.l1BackfillBuffer:
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
if err := h.primary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", redactKey(item.key), err)
} else {
h.logger.Debugf("Populated L1 cache from L2 for key: %s", redactKey(item.key))
}
cancel()
}
}
}
// asyncWriteWorker processes asynchronous writes to L2
func (h *HybridBackend) asyncWriteWorker() {
defer h.wg.Done()
@@ -572,11 +619,11 @@ func (h *HybridBackend) asyncWriteWorker() {
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
h.errors.Add(1)
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
h.logger.Debugf("Async write to L2 failed for key %s: %v", redactKey(item.key), err)
h.recordL2Error()
} else {
h.l2Writes.Add(1)
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
h.logger.Debugf("Async write to L2 completed for key: %s", redactKey(item.key))
}
cancel()
}
+112
View File
@@ -0,0 +1,112 @@
//go:build !yaegi
package backends
import (
"context"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHybridBackend_L1BackfillBounded verifies that a burst of L2 hits does
// not detonate the goroutine count. Pre-fix the code spawned one goroutine
// per Get() L2 hit; post-fix all backfills funnel through a single worker.
func TestHybridBackend_L1BackfillBounded(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 256,
})
require.NoError(t, err)
defer hybrid.Close()
ctx := context.Background()
const burst = 1000
// Pre-populate L2 with `burst` distinct keys so each Get triggers a
// fresh L1 backfill enqueue.
for i := 0; i < burst; i++ {
require.NoError(t, secondary.Set(ctx, fmt.Sprintf("k:%d", i), []byte("v"), time.Minute))
}
baseline := runtime.NumGoroutine()
// Issue the burst as fast as possible; the backfill worker MUST be the
// only goroutine doing L1 writes. Allow brief slack for the test runtime
// scheduling but anything north of +20 means goroutine leakage.
peak := baseline
for i := 0; i < burst; i++ {
_, _, exists, err := hybrid.Get(ctx, fmt.Sprintf("k:%d", i))
require.NoError(t, err)
require.True(t, exists)
if g := runtime.NumGoroutine(); g > peak {
peak = g
}
}
delta := peak - baseline
if delta > 20 {
t.Fatalf("goroutine count grew by %d during burst (baseline=%d peak=%d); backfill worker not bounding goroutines",
delta, baseline, peak)
}
// L1 must eventually catch up via the worker. Worker drains serially so
// give it a generous window proportional to the burst size.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
var populated int
for i := 0; i < burst; i++ {
if _, _, ok, _ := primary.Get(ctx, fmt.Sprintf("k:%d", i)); ok {
populated++
}
}
// Be lenient: drops are acceptable under buffer pressure, just want
// most of the keys to make it.
if populated >= burst-int(hybrid.l1BackfillDrops.Load()) {
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("L1 not backfilled within deadline: l2Hits=%d l1Writes=%d drops=%d",
hybrid.l2Hits.Load(), hybrid.l1Writes.Load(), hybrid.l1BackfillDrops.Load())
}
// TestHybridBackend_L1BackfillFullDrops verifies the drop semantics when the
// buffer is saturated. Drops must be counted, never block, never spawn a
// goroutine.
func TestHybridBackend_L1BackfillFullDrops(t *testing.T) {
primary := newMockBackend()
secondary := newMockBackend()
// Tiny buffer + slow primary writes via failSet so the worker stays
// blocked enough to overflow the buffer.
hybrid, err := NewHybridBackend(&HybridConfig{
Primary: primary,
Secondary: secondary,
AsyncBufferSize: 4,
})
require.NoError(t, err)
defer hybrid.Close()
// Stop the worker from draining: cancel the underlying context so the
// worker bails out, leaving us with a cold buffer and the queue method
// itself responsible for drop accounting.
hybrid.cancel()
// Wait for worker to exit so it can't drain.
time.Sleep(50 * time.Millisecond)
for i := 0; i < 50; i++ {
hybrid.queueL1Backfill(fmt.Sprintf("k:%d", i), []byte("v"), time.Minute)
}
assert.Greater(t, hybrid.l1BackfillDrops.Load(), int64(0),
"expected some drops when buffer is saturated and worker is stopped")
}
+26
View File
@@ -0,0 +1,26 @@
// Package backends provides cache backend implementations for the Traefik OIDC plugin.
package backends
import (
"crypto/sha256"
"encoding/hex"
)
// redactKey returns a short, deterministic hash prefix of a cache key for use
// in debug/info log lines. Cache keys in this plugin can include raw access /
// refresh / id tokens (any caller may pass an arbitrary string), and CodeQL
// flags `key=%s` formatters as a clear-text-logging sink for HTTP-header-
// sourced taint. The hash preserves cache-key uniqueness in logs (same key →
// same hash, useful for correlating a problematic key across log lines) while
// keeping the raw value out of disk-resident log streams.
//
// 8 hex chars (32 bits) is enough to disambiguate at human-debugging scale
// without making the hash itself a useful lookup primitive for an attacker
// who only has the log stream.
func redactKey(key string) string {
if key == "" {
return "(empty)"
}
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:4])
}
+9 -5
View File
@@ -241,9 +241,11 @@ func (s *cacheShard) evictLRULocked() bool {
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
s.deleteItemLocked(item)
return true
item, ok := element.Value.(*memoryCacheItem)
if ok {
s.deleteItemLocked(item)
return true
}
}
return false
}
@@ -267,8 +269,10 @@ func (s *cacheShard) getOldestAccessTime() time.Time {
element := s.lruList.Back()
if element != nil {
item := element.Value.(*memoryCacheItem)
return item.accessedAt
item, ok := element.Value.(*memoryCacheItem)
if ok {
return item.accessedAt
}
}
return time.Time{}
}
+5 -2
View File
@@ -49,6 +49,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
poolConfig := &PoolConfig{
Address: config.RedisAddr,
Password: config.RedisPassword,
TLSServerName: config.TLSServerName,
DB: config.RedisDB,
MaxConnections: config.PoolSize,
ConnectTimeout: 2 * time.Second,
@@ -57,6 +58,8 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
EnableHealthCheck: true,
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
EnableTLS: config.EnableTLS,
TLSSkipVerify: config.TLSSkipVerify,
}
pool, err := NewConnectionPool(poolConfig)
@@ -345,7 +348,7 @@ func (r *RedisBackend) prefixKey(key string) string {
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
// It checks context cancellation at multiple points to ensure fast abort when the
// caller's context is cancelled (e.g., due to request timeout).
// caller's context is canceled (e.g., due to request timeout).
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
maxRetries := 3
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
@@ -377,7 +380,7 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
err = operation(conn)
r.pool.Put(conn)
// Check context after operation - if cancelled, don't bother retrying
// Check context after operation - if canceled, don't bother retrying
if ctx.Err() != nil {
return ctx.Err()
}
+25 -3
View File
@@ -2,6 +2,7 @@ package backends
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
@@ -31,6 +32,7 @@ type ConnectionPool struct {
type PoolConfig struct {
Address string
Password string
TLSServerName string // SNI server name; defaults to host(Address) when empty
DB int
MaxConnections int
ConnectTimeout time.Duration
@@ -39,6 +41,8 @@ type PoolConfig struct {
EnableHealthCheck bool // Enable connection health validation
MaxRetries int // Max retries for failed operations
RetryDelay time.Duration // Initial delay between retries
EnableTLS bool // Wrap connection with TLS (e.g. AWS ElastiCache in-transit encryption)
TLSSkipVerify bool // Skip server certificate verification (escape hatch; not recommended)
}
// NewConnectionPool creates a new connection pool
@@ -96,7 +100,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
// No available connection, create new one if under limit
// #nosec G115 -- MaxConnections is a small config value that fits in int32
if p.totalConns.Load() < int32(p.config.MaxConnections) {
conn, err = p.createConnection()
conn, err = p.createConnection(ctx)
if err != nil {
// If this is the last attempt, return error
if attempt == maxAttempts-1 {
@@ -193,13 +197,31 @@ func (p *ConnectionPool) Stats() map[string]interface{} {
}
// createConnection creates a new Redis connection
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
func (p *ConnectionPool) createConnection(ctx context.Context) (*RedisConn, error) {
// Connect with timeout
dialer := &net.Dialer{
Timeout: p.config.ConnectTimeout,
}
conn, err := dialer.Dial("tcp", p.config.Address)
var conn net.Conn
var err error
if p.config.EnableTLS {
serverName := p.config.TLSServerName
if serverName == "" {
if host, _, splitErr := net.SplitHostPort(p.config.Address); splitErr == nil {
serverName = host
}
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: p.config.TLSSkipVerify, // #nosec G402 -- opt-in escape hatch via TLSSkipVerify config
MinVersion: tls.VersionTLS12,
}
tlsDialer := &tls.Dialer{NetDialer: dialer, Config: tlsCfg}
conn, err = tlsDialer.DialContext(ctx, "tcp", p.config.Address)
} else {
conn, err = dialer.DialContext(ctx, "tcp", p.config.Address)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
+31 -1
View File
@@ -3,6 +3,7 @@ package backends
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
@@ -201,7 +202,7 @@ func TestConnectionPool_ContextCancellation(t *testing.T) {
conn, err := pool.Get(context.Background())
require.NoError(t, err)
// Try to get another with cancelled context
// Try to get another with canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
@@ -617,4 +618,33 @@ func TestRedisConn_TooManyArguments(t *testing.T) {
assert.NotContains(t, err.Error(), "too many arguments")
}
})
}
// TestRedisConn_RejectOversizedArgumentBytes is a regression test for CodeQL
// alert #10 (go/allocation-size-overflow). A single argument larger than
// maxTotalArgBytes (64 MiB) must be rejected by the per-argument overflow
// guard in Do() before any allocation is attempted.
func TestRedisConn_RejectOversizedArgumentBytes(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
largeArg := strings.Repeat("x", (64<<20)+1)
_, err = conn.Do("SET", "k", largeArg)
require.Error(t, err)
assert.Contains(t, err.Error(), "arguments too large")
}
+230
View File
@@ -0,0 +1,230 @@
package backends
import (
"bufio"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// drainRESPRequest consumes a single RESP request (array or inline) from r and
// returns true on success. Any read error returns false.
func drainRESPRequest(r *bufio.Reader) bool {
header, err := r.ReadString('\n')
if err != nil {
return false
}
if !strings.HasPrefix(header, "*") {
return true // inline command (single line) — already consumed
}
n, err := strconv.Atoi(strings.TrimRight(strings.TrimPrefix(header, "*"), "\r\n"))
if err != nil || n <= 0 {
return false
}
for i := 0; i < n; i++ {
// Each bulk: "$len\r\n<bytes>\r\n"
if _, err := r.ReadString('\n'); err != nil {
return false
}
if _, err := r.ReadString('\n'); err != nil {
return false
}
}
return true
}
// startTLSPingServer spins up a TLS listener that speaks just enough RESP to
// answer PING with +PONG. Returns the listener address and a self-signed cert.
func startTLSPingServer(t *testing.T) (addr string, certPEM []byte, stop func()) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
require.NoError(t, err)
tlsCert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS12,
})
require.NoError(t, err)
var wg sync.WaitGroup
stopCh := make(chan struct{})
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
}
c, acceptErr := listener.Accept()
if acceptErr != nil {
return
}
wg.Add(1)
go func(conn net.Conn) {
defer wg.Done()
defer conn.Close()
reader := bufio.NewReader(conn)
for {
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if !drainRESPRequest(reader) {
return
}
_, _ = conn.Write([]byte("+PONG\r\n"))
}
}(c)
}
}()
stop = func() {
close(stopCh)
_ = listener.Close()
wg.Wait()
}
return listener.Addr().String(), der, stop
}
// TestConnectionPool_TLSDial_SkipVerify verifies that EnableTLS=true with
// TLSSkipVerify=true successfully negotiates TLS and exchanges a Redis command.
// Regression test for issue #133 (enableTLS not propagated to client).
func TestConnectionPool_TLSDial_SkipVerify(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, conn)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
// TestConnectionPool_TLSDial_VerifyFails verifies that EnableTLS=true with
// TLSSkipVerify=false rejects a self-signed server cert.
func TestConnectionPool_TLSDial_VerifyFails(t *testing.T) {
addr, _, stop := startTLSPingServer(t)
defer stop()
pool, err := NewConnectionPool(&PoolConfig{
Address: addr,
MaxConnections: 2,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: false,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), "tls")
}
// TestConnectionPool_TLSDial_PlainServerRejected verifies that EnableTLS=true
// fails to handshake against a plain (non-TLS) listener.
func TestConnectionPool_TLSDial_PlainServerRejected(t *testing.T) {
plain, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer plain.Close()
go func() {
for {
c, acceptErr := plain.Accept()
if acceptErr != nil {
return
}
_ = c.Close()
}
}()
pool, err := NewConnectionPool(&PoolConfig{
Address: plain.Addr().String(),
MaxConnections: 1,
ConnectTimeout: 1 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
defer pool.Close()
_, err = pool.Get(context.Background())
require.Error(t, err)
}
// TestConnectionPool_PlainDial_StillWorks ensures non-TLS path is unaffected
// when EnableTLS=false (default).
func TestConnectionPool_PlainDial_StillWorks(t *testing.T) {
mr := NewMiniredisServer(t)
pool, err := NewConnectionPool(&PoolConfig{
Address: mr.GetAddr(),
MaxConnections: 1,
ConnectTimeout: 2 * time.Second,
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
EnableTLS: false,
})
require.NoError(t, err)
defer pool.Close()
conn, err := pool.Get(context.Background())
require.NoError(t, err)
defer pool.Put(conn)
resp, err := conn.Do("PING")
require.NoError(t, err)
assert.Equal(t, "PONG", resp)
}
+15 -34
View File
@@ -7,52 +7,34 @@ import (
"io"
"strconv"
"strings"
"sync"
)
// RESP (REdis Serialization Protocol) implementation
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
//
// NOTE: sync.Pool was intentionally removed for Yaegi compatibility.
// Yaegi (Traefik's Go interpreter) has issues with sync.Pool and reflection
// that cause "reflect: call of reflect.Value.Field on zero Value" panics.
// See: https://github.com/lukaszraczylo/traefikoidc/issues/120
var (
ErrInvalidRESP = errors.New("invalid RESP response")
ErrNilResponse = errors.New("nil response")
)
// Object pools for memory optimization - reduces allocations by 50-70%
var (
readerPool = sync.Pool{
New: func() interface{} {
return &RESPReader{
r: bufio.NewReaderSize(nil, 4096),
}
},
}
writerPool = sync.Pool{
New: func() interface{} {
return &RESPWriter{
w: nil,
}
},
}
)
// RESPWriter writes RESP protocol messages
type RESPWriter struct {
w io.Writer
}
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
// NewRESPWriter creates a new RESP writer
func NewRESPWriter(w io.Writer) *RESPWriter {
writer := writerPool.Get().(*RESPWriter)
writer.w = w
return writer
return &RESPWriter{w: w}
}
// Release returns the writer to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (w *RESPWriter) Release() {
w.w = nil
writerPool.Put(w)
// No-op: pooling removed for Yaegi compatibility
}
// WriteCommand writes a Redis command in RESP array format
@@ -78,17 +60,16 @@ type RESPReader struct {
r *bufio.Reader
}
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
// NewRESPReader creates a new RESP reader
func NewRESPReader(r io.Reader) *RESPReader {
reader := readerPool.Get().(*RESPReader)
reader.r.Reset(r)
return reader
return &RESPReader{
r: bufio.NewReaderSize(r, 4096),
}
}
// Release returns the reader to the pool for reuse
// Release is a no-op for API compatibility (pooling removed for Yaegi compatibility)
func (r *RESPReader) Release() {
r.r.Reset(nil)
readerPool.Put(r)
// No-op: pooling removed for Yaegi compatibility
}
// ReadResponse reads a RESP response and returns the parsed value
+1 -1
View File
@@ -87,7 +87,7 @@ func (s *SingleflightCache) GetOrFetch(ctx context.Context, key string, fetcher
// If successful, store in cache
if call.err == nil && call.val != nil {
// Use a background context for cache storage to ensure it completes
// even if the original context is cancelled
// even if the original context is canceled
storeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = s.backend.Set(storeCtx, key, call.val, call.ttl)
cancel()
+2 -2
View File
@@ -190,7 +190,7 @@ func (c *Cache) Set(key string, value interface{}, ttl time.Duration) error {
c.currentSize++
atomic.AddInt64(&c.sets, 1)
c.logger.Debugf("Cache: Set key=%s, size=%d, ttl=%v", key, size, ttl)
c.logger.Debugf("Cache: Set key=%s, size=%d, ttl=%v", redactKey(key), size, ttl)
return nil
}
@@ -346,7 +346,7 @@ func (c *Cache) evictLRU() {
item, _ := elem.Value.(*Item) // Safe to ignore: type assertion from known type
c.removeItem(item.Key, item)
atomic.AddInt64(&c.evictions, 1)
c.logger.Debugf("Cache: Evicted LRU item key=%s", item.Key)
c.logger.Debugf("Cache: Evicted LRU item key=%s", redactKey(item.Key))
}
}
+22
View File
@@ -0,0 +1,22 @@
// Package cache provides the in-memory cache implementation for the Traefik
// OIDC plugin.
package cache
import (
"crypto/sha256"
"encoding/hex"
)
// redactKey returns a short, deterministic hash prefix of a cache key for use
// in debug/info log lines. Cache keys may include raw access / refresh / id
// tokens (callers pass arbitrary strings) and CodeQL flags `key=%s`
// formatters as a clear-text-logging sink for HTTP-header-sourced taint.
// The hash preserves uniqueness in logs (same key → same hash) while keeping
// the raw value out of disk-resident log streams.
func redactKey(key string) string {
if key == "" {
return "(empty)"
}
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:4])
}
+1 -1
View File
@@ -232,7 +232,7 @@ func (m *Manager) Close() error {
var firstErr error
if err := m.tokenCache.Close(); err != nil && firstErr == nil {
if err := m.tokenCache.Close(); err != nil {
firstErr = err
}
if err := m.metadataCache.Close(); err != nil && firstErr == nil {
+12 -4
View File
@@ -842,10 +842,18 @@ func TestWorkerPool_TaskPanic(t *testing.T) {
t.Error("Timeout waiting for tasks")
}
// Pool should still be functional
metrics := pool.GetMetrics()
if metrics["tasksFailed"].(int64) < 1 {
t.Error("Expected at least one failed task")
// tasksFailed is incremented in the worker's deferred recover(), which runs
// AFTER the panicking task's own `defer wg.Done()`. wg.Wait() above can
// therefore return before the failure is recorded — reading the counter
// immediately is a race that flakes on slow/contended CI runners. Poll until
// the failure lands (or time out).
deadline := time.Now().Add(2 * time.Second)
for pool.GetMetrics()["tasksFailed"].(int64) < 1 {
if time.Now().After(deadline) {
t.Error("Expected at least one failed task")
break
}
time.Sleep(5 * time.Millisecond)
}
}
+1 -1
View File
@@ -397,7 +397,7 @@ func (wp *WorkerPool) Submit(task func()) error {
}
// worker is the main worker routine
func (wp *WorkerPool) worker(id int) {
func (wp *WorkerPool) worker(_ int) {
defer wp.workerWg.Done()
for {
+155
View File
@@ -0,0 +1,155 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
)
// FileStore implements Store using file-based storage.
// This is the default storage backend for backward compatibility with existing deployments.
// For distributed environments, consider using RedisStore instead.
type FileStore struct {
basePath string
logger Logger
mu sync.RWMutex
}
// NewFileStore creates a new file-based credentials store.
// If basePath is empty, defaults to /tmp/oidc-client-credentials.json
func NewFileStore(basePath string, logger Logger) *FileStore {
if basePath == "" {
basePath = "/tmp/oidc-client-credentials.json"
}
if logger == nil {
logger = NoOpLogger()
}
return &FileStore{
basePath: basePath,
logger: logger,
}
}
// BasePath returns the base path used for storing credentials
func (s *FileStore) BasePath() string {
return s.basePath
}
// GetFilePath returns the file path for storing credentials for a specific provider.
// For multi-tenant scenarios, each provider gets a separate file based on URL hash.
func (s *FileStore) GetFilePath(providerURL string) string {
if providerURL == "" {
return s.basePath
}
// Hash provider URL for filename safety and uniqueness
hash := sha256.Sum256([]byte(providerURL))
hashStr := hex.EncodeToString(hash[:8]) // Use first 8 bytes for shorter filename
ext := filepath.Ext(s.basePath)
base := strings.TrimSuffix(s.basePath, ext)
if ext == "" {
ext = ".json"
}
return fmt.Sprintf("%s-%s%s", base, hashStr, ext)
}
// Save stores the client registration response to a file
func (s *FileStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
// Ensure parent directory exists
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create credentials directory: %w", err)
}
data, err := json.MarshalIndent(creds, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Write with restrictive permissions (owner read/write only)
if err := os.WriteFile(filePath, data, 0600); err != nil {
return fmt.Errorf("failed to write credentials file: %w", err)
}
s.logger.Debugf("Saved client credentials to %s", filePath)
return nil
}
// Load retrieves stored credentials from a file.
// Returns nil, nil if no credentials file exists (not an error).
func (s *FileStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
// #nosec G304 -- path is constructed from trusted config values via GetFilePath()
data, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil // No credentials file exists - not an error
}
return nil, fmt.Errorf("failed to read credentials file: %w", err)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal(data, &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
}
s.logger.Debugf("Loaded client credentials from %s", filePath)
return &creds, nil
}
// Delete removes the credentials file for a provider
func (s *FileStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
filePath := s.GetFilePath(providerURL)
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
return nil // File doesn't exist, nothing to delete
}
return fmt.Errorf("failed to remove credentials file: %w", err)
}
s.logger.Debugf("Deleted client credentials from %s", filePath)
return nil
}
// Exists checks if credentials exist for a provider
func (s *FileStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
filePath := s.GetFilePath(providerURL)
_, err := os.Stat(filePath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("failed to check credentials file: %w", err)
}
return true, nil
}
+161
View File
@@ -0,0 +1,161 @@
package dcrstorage
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
)
// Cache defines the interface for cache operations needed by RedisStore.
// This allows the main package to provide a cache implementation without
// creating circular dependencies.
type Cache interface {
// Get retrieves a value from the cache
Get(key string) (any, bool)
// Set stores a value in the cache with a TTL
Set(key string, value any, ttl time.Duration) error
// Delete removes a value from the cache
Delete(key string)
}
// RedisStore implements Store using a Cache-backed storage.
// This storage backend enables sharing DCR credentials across multiple Traefik instances
// in distributed environments (e.g., Kubernetes with multiple ingress pods).
type RedisStore struct {
cache Cache
keyPrefix string
logger Logger
mu sync.RWMutex
}
// NewRedisStore creates a new cache-backed credentials store.
// The cache should be configured with a Redis backend for distributed storage.
// If keyPrefix is empty, defaults to "dcr:creds:"
func NewRedisStore(cache Cache, keyPrefix string, logger Logger) *RedisStore {
if keyPrefix == "" {
keyPrefix = "dcr:creds:"
}
if logger == nil {
logger = NoOpLogger()
}
return &RedisStore{
cache: cache,
keyPrefix: keyPrefix,
logger: logger,
}
}
// makeKey creates a unique cache key for a provider URL.
// Uses SHA256 hash of the provider URL for consistent key generation across nodes.
func (s *RedisStore) makeKey(providerURL string) string {
if providerURL == "" {
return s.keyPrefix + "default"
}
hash := sha256.Sum256([]byte(providerURL))
return s.keyPrefix + hex.EncodeToString(hash[:])
}
// Save stores the client registration response in the cache.
// TTL is calculated based on client_secret_expires_at if available.
func (s *RedisStore) Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error {
if creds == nil {
return fmt.Errorf("credentials cannot be nil")
}
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
// Calculate TTL based on client_secret_expires_at if available
ttl := 30 * 24 * time.Hour // Default: 30 days
if creds.ClientSecretExpiresAt > 0 {
expiresAt := time.Unix(creds.ClientSecretExpiresAt, 0)
ttl = time.Until(expiresAt)
if ttl < 0 {
return fmt.Errorf("credentials already expired")
}
// Add a small buffer to ensure we don't serve expired credentials
if ttl > time.Minute {
ttl -= time.Minute
}
}
// Serialize credentials to JSON for storage
data, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("failed to marshal credentials: %w", err)
}
// Store as string in cache (will be serialized by the cache backend)
if err := s.cache.Set(key, string(data), ttl); err != nil {
return fmt.Errorf("failed to store credentials in cache: %w", err)
}
s.logger.Debugf("Saved client credentials to cache with key %s (TTL: %v)", key, ttl)
return nil
}
// Load retrieves stored credentials from the cache.
// Returns nil, nil if no credentials exist (not an error).
func (s *RedisStore) Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
value, exists := s.cache.Get(key)
if !exists {
return nil, nil // No credentials stored - not an error
}
// Handle different value types from cache
var jsonData string
switch v := value.(type) {
case string:
jsonData = v
case []byte:
jsonData = string(v)
default:
// Try to see if it's already the struct (from local cache)
if creds, ok := value.(*ClientRegistrationResponse); ok {
return creds, nil
}
return nil, fmt.Errorf("unexpected credentials type in cache: %T", value)
}
var creds ClientRegistrationResponse
if err := json.Unmarshal([]byte(jsonData), &creds); err != nil {
return nil, fmt.Errorf("failed to parse credentials from cache: %w", err)
}
s.logger.Debugf("Loaded client credentials from cache with key %s", key)
return &creds, nil
}
// Delete removes stored credentials from the cache
func (s *RedisStore) Delete(ctx context.Context, providerURL string) error {
s.mu.Lock()
defer s.mu.Unlock()
key := s.makeKey(providerURL)
s.cache.Delete(key)
s.logger.Debugf("Deleted client credentials from cache with key %s", key)
return nil
}
// Exists checks if credentials exist in the cache for a provider
func (s *RedisStore) Exists(ctx context.Context, providerURL string) (bool, error) {
s.mu.RLock()
defer s.mu.RUnlock()
key := s.makeKey(providerURL)
_, exists := s.cache.Get(key)
return exists, nil
}
+90
View File
@@ -0,0 +1,90 @@
// Package dcrstorage provides storage backends for OIDC Dynamic Client Registration credentials.
// It supports both file-based and Redis-based storage for persisting client credentials
// across application restarts and distributed deployments.
package dcrstorage
import (
"context"
)
// StorageBackend represents the type of storage backend for DCR credentials
type StorageBackend string
const (
// StorageBackendFile uses file-based storage (default for backward compatibility)
StorageBackendFile StorageBackend = "file"
// StorageBackendRedis uses Redis for distributed storage
StorageBackendRedis StorageBackend = "redis"
// StorageBackendAuto automatically selects Redis if available, otherwise file
StorageBackendAuto StorageBackend = "auto"
)
// Logger interface for DCR storage operations
type Logger interface {
Debug(msg string)
Debugf(format string, args ...any)
Info(msg string)
Infof(format string, args ...any)
Error(msg string)
Errorf(format string, args ...any)
}
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
type ClientRegistrationResponse struct {
SubjectType string `json:"subject_type,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
Scope string `json:"scope,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
TOSURI string `json:"tos_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
ClientID string `json:"client_id"`
ClientName string `json:"client_name,omitempty"`
JWKSURI string `json:"jwks_uri,omitempty"`
ClientURI string `json:"client_uri,omitempty"`
Contacts []string `json:"contacts,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
}
// Store defines the interface for storing DCR credentials.
// This abstraction allows different storage backends (file, Redis) to be used
// for persisting OIDC Dynamic Client Registration credentials across nodes.
type Store interface {
// Save stores the client registration response for a provider
// The providerURL is used as a key to support multi-tenant scenarios
Save(ctx context.Context, providerURL string, creds *ClientRegistrationResponse) error
// Load retrieves stored credentials for a provider
// Returns nil, nil if no credentials exist (not an error)
Load(ctx context.Context, providerURL string) (*ClientRegistrationResponse, error)
// Delete removes stored credentials for a provider
Delete(ctx context.Context, providerURL string) error
// Exists checks if credentials exist for a provider
Exists(ctx context.Context, providerURL string) (bool, error)
}
// noOpLogger is a no-op implementation of Logger for default use
type noOpLogger struct{}
func (n noOpLogger) Debug(msg string) {}
func (n noOpLogger) Debugf(format string, args ...any) {}
func (n noOpLogger) Info(msg string) {}
func (n noOpLogger) Infof(format string, args ...any) {}
func (n noOpLogger) Error(msg string) {}
func (n noOpLogger) Errorf(format string, args ...any) {}
// NoOpLogger returns a no-op logger instance
func NoOpLogger() Logger {
return noOpLogger{}
}
+464
View File
@@ -0,0 +1,464 @@
package dcrstorage
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
)
// mockCache implements Cache for testing
type mockCache struct {
data map[string]cacheEntry
mu sync.RWMutex
}
type cacheEntry struct {
value any
expiresAt time.Time
}
func newMockCache() *mockCache {
return &mockCache{data: make(map[string]cacheEntry)}
}
func (m *mockCache) Get(key string) (any, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
entry, ok := m.data[key]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
return nil, false
}
return entry.value, true
}
func (m *mockCache) Set(key string, value any, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = cacheEntry{
value: value,
expiresAt: time.Now().Add(ttl),
}
return nil
}
func (m *mockCache) Delete(key string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
}
func TestFileStore_SaveLoad(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
testCreds := &ClientRegistrationResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "test-access-token",
RegistrationClientURI: "https://example.com/register/test-client-id",
RedirectURIs: []string{"https://app.example.com/callback"},
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
TokenEndpointAuthMethod: "client_secret_basic",
}
ctx := context.Background()
providerURL := "https://auth.example.com"
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
if loaded.RegistrationAccessToken != testCreds.RegistrationAccessToken {
t.Errorf("RegistrationAccessToken mismatch: got %s, want %s", loaded.RegistrationAccessToken, testCreds.RegistrationAccessToken)
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
tempDir2 := t.TempDir()
store2 := NewFileStore(filepath.Join(tempDir2, "nonexistent.json"), nil)
loaded, err := store2.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent file: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
exists, err = store.Exists(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if exists {
t.Error("Expected credentials to not exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("delete non-existent credentials", func(t *testing.T) {
err := store.Delete(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Delete should not error for non-existent: %v", err)
}
})
}
func TestFileStore_MultiProvider(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
provider1 := "https://auth1.example.com"
provider2 := "https://auth2.example.com"
creds1 := &ClientRegistrationResponse{
ClientID: "client-1",
ClientSecret: "secret-1",
}
creds2 := &ClientRegistrationResponse{
ClientID: "client-2",
ClientSecret: "secret-2",
}
if err := store.Save(ctx, provider1, creds1); err != nil {
t.Fatalf("Failed to save creds1: %v", err)
}
if err := store.Save(ctx, provider2, creds2); err != nil {
t.Fatalf("Failed to save creds2: %v", err)
}
loaded1, err := store.Load(ctx, provider1)
if err != nil {
t.Fatalf("Failed to load creds1: %v", err)
}
if loaded1.ClientID != "client-1" {
t.Errorf("Provider 1 ClientID mismatch: got %s", loaded1.ClientID)
}
loaded2, err := store.Load(ctx, provider2)
if err != nil {
t.Fatalf("Failed to load creds2: %v", err)
}
if loaded2.ClientID != "client-2" {
t.Errorf("Provider 2 ClientID mismatch: got %s", loaded2.ClientID)
}
if err := store.Delete(ctx, provider1); err != nil {
t.Fatalf("Failed to delete creds1: %v", err)
}
exists, _ := store.Exists(ctx, provider2)
if !exists {
t.Error("Provider 2 credentials should still exist")
}
}
func TestFileStore_ConcurrentAccess(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
creds := &ClientRegistrationResponse{
ClientID: "test-client",
ClientSecret: "test-secret",
}
var wg sync.WaitGroup
concurrency := 10
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_ = store.Save(ctx, providerURL, creds)
}()
}
wg.Wait()
for range concurrency {
wg.Add(1)
go func() {
defer wg.Done()
_, _ = store.Load(ctx, providerURL)
}()
}
wg.Wait()
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load after concurrent access: %v", err)
}
if loaded == nil || loaded.ClientID != "test-client" {
t.Error("Credentials corrupted after concurrent access")
}
}
func TestFileStore_InvalidInput(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
t.Run("empty provider URL uses default path", func(t *testing.T) {
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "", creds)
if err != nil {
t.Fatalf("Save with empty provider URL failed: %v", err)
}
loaded, err := store.Load(ctx, "")
if err != nil {
t.Fatalf("Load with empty provider URL failed: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials with empty provider URL")
}
})
}
func TestFileStore_DefaultPath(t *testing.T) {
t.Parallel()
store := NewFileStore("", nil)
if store.BasePath() == "" {
t.Error("Expected default base path")
}
}
func TestRedisStore_WithMockCache(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
testCreds := &ClientRegistrationResponse{
ClientID: "redis-test-client",
ClientSecret: "redis-test-secret",
ClientSecretExpiresAt: time.Now().Add(24 * time.Hour).Unix(),
RegistrationAccessToken: "redis-test-token",
RedirectURIs: []string{"https://app.example.com/callback"},
}
t.Run("save and load credentials", func(t *testing.T) {
err := store.Save(ctx, providerURL, testCreds)
if err != nil {
t.Fatalf("Failed to save credentials: %v", err)
}
loaded, err := store.Load(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to load credentials: %v", err)
}
if loaded == nil {
t.Fatal("Expected credentials but got nil")
}
if loaded.ClientID != testCreds.ClientID {
t.Errorf("ClientID mismatch: got %s, want %s", loaded.ClientID, testCreds.ClientID)
}
if loaded.ClientSecret != testCreds.ClientSecret {
t.Errorf("ClientSecret mismatch: got %s, want %s", loaded.ClientSecret, testCreds.ClientSecret)
}
})
t.Run("exists check", func(t *testing.T) {
exists, err := store.Exists(ctx, providerURL)
if err != nil {
t.Fatalf("Exists check failed: %v", err)
}
if !exists {
t.Error("Expected credentials to exist")
}
})
t.Run("delete credentials", func(t *testing.T) {
err := store.Delete(ctx, providerURL)
if err != nil {
t.Fatalf("Failed to delete credentials: %v", err)
}
exists, _ := store.Exists(ctx, providerURL)
if exists {
t.Error("Expected credentials to be deleted")
}
})
t.Run("load non-existent credentials", func(t *testing.T) {
loaded, err := store.Load(ctx, "https://nonexistent.example.com")
if err != nil {
t.Fatalf("Unexpected error for non-existent: %v", err)
}
if loaded != nil {
t.Error("Expected nil for non-existent credentials")
}
})
}
func TestRedisStore_TTLFromExpiry(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("expired credentials should fail", func(t *testing.T) {
expiredCreds := &ClientRegistrationResponse{
ClientID: "expired-client",
ClientSecret: "expired-secret",
ClientSecretExpiresAt: time.Now().Add(-1 * time.Hour).Unix(),
}
err := store.Save(ctx, "https://expired.example.com", expiredCreds)
if err == nil {
t.Error("Expected error for expired credentials")
}
})
t.Run("credentials without expiry use default TTL", func(t *testing.T) {
creds := &ClientRegistrationResponse{
ClientID: "no-expiry-client",
ClientSecret: "no-expiry-secret",
ClientSecretExpiresAt: 0,
}
err := store.Save(ctx, "https://noexpiry.example.com", creds)
if err != nil {
t.Fatalf("Failed to save credentials without expiry: %v", err)
}
})
}
func TestRedisStore_InvalidInput(t *testing.T) {
t.Parallel()
cache := newMockCache()
store := NewRedisStore(cache, "", nil)
ctx := context.Background()
t.Run("save nil credentials", func(t *testing.T) {
err := store.Save(ctx, "https://example.com", nil)
if err == nil {
t.Error("Expected error for nil credentials")
}
})
}
func TestFileStore_CorruptedFile(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
basePath := filepath.Join(tempDir, "credentials.json")
store := NewFileStore(basePath, nil)
ctx := context.Background()
providerURL := "https://auth.example.com"
filePath := store.GetFilePath(providerURL)
if err := os.WriteFile(filePath, []byte("{corrupted json"), 0600); err != nil {
t.Fatalf("Failed to write corrupted file: %v", err)
}
_, err := store.Load(ctx, providerURL)
if err == nil {
t.Error("Expected error for corrupted JSON")
}
}
func TestFileStore_DirectoryCreation(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
deepPath := filepath.Join(tempDir, "deep", "nested", "path", "credentials.json")
store := NewFileStore(deepPath, nil)
ctx := context.Background()
creds := &ClientRegistrationResponse{ClientID: "test"}
err := store.Save(ctx, "https://example.com", creds)
if err != nil {
t.Fatalf("Failed to save with nested directory: %v", err)
}
loaded, err := store.Load(ctx, "https://example.com")
if err != nil {
t.Fatalf("Failed to load after nested directory creation: %v", err)
}
if loaded == nil || loaded.ClientID != "test" {
t.Error("Failed to load credentials from nested directory")
}
}
+1 -1
View File
@@ -173,7 +173,7 @@ func (m *FeatureManager) LoadFromEnv() {
for name, flag := range flags {
envVar := "FEATURE_" + name
if value := os.Getenv(envVar); value != "" {
enabled := strings.ToLower(value) == "true" || value == "1"
enabled := strings.EqualFold(value, "true") || value == "1"
flag.enabled.Store(enabled)
}
}
+1 -1
View File
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
var filteredScopes []string
for _, scope := range scopes {
if strings.ToLower(scope) != ScopeOfflineAccess {
if !strings.EqualFold(scope, ScopeOfflineAccess) {
filteredScopes = append(filteredScopes, scope)
}
}
+2 -1
View File
@@ -147,7 +147,8 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
return p
}
case ProviderTypeKeycloak:
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/auth/realms/") {
// Match both Keycloak <17 (`/auth/realms/`) and 17+ (`/realms/`).
if strings.Contains(host, "keycloak") || strings.Contains(normalizedURL.Path, "/realms/") {
return p
}
case ProviderTypeAWSCognito:
+6 -1
View File
@@ -225,10 +225,15 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
expected: oktaProvider,
},
{
name: "Keycloak provider detection",
name: "Keycloak provider detection (legacy /auth/realms/)",
issuerURL: "https://auth.example.com/auth/realms/master",
expected: keycloakProvider,
},
{
name: "Keycloak provider detection (modern /realms/, KC 17+)",
issuerURL: "https://auth.example.com/realms/master",
expected: keycloakProvider,
},
{
name: "AWS Cognito provider detection",
issuerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example",
+11 -10
View File
@@ -18,16 +18,17 @@ func GetProviderWarnings(providerType ProviderType) []ProviderWarning {
switch providerType {
case ProviderTypeGitHub:
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
})
warnings = append(warnings, ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
warnings = append(warnings,
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "warning",
Message: "GitHub uses OAuth 2.0, not OpenID Connect. ID tokens are not available. Use access tokens for API calls only.",
},
ProviderWarning{
ProviderType: ProviderTypeGitHub,
Level: "info",
Message: "GitHub OAuth apps do not support refresh tokens. Users will need to re-authenticate when tokens expire.",
})
case ProviderTypeAuth0:
warnings = append(warnings, ProviderWarning{
+4 -3
View File
@@ -116,7 +116,7 @@ func (re *RetryExecutor) ExecuteWithContext(ctx context.Context, fn func() error
// Continue to next attempt
case <-ctx.Done():
re.RecordFailure()
return fmt.Errorf("retry cancelled: %w", ctx.Err())
return fmt.Errorf("retry canceled: %w", ctx.Err())
}
}
@@ -301,7 +301,7 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
}
}
allMetrics["summary"] = map[string]interface{}{
summary := map[string]interface{}{
"totalMechanisms": len(rm.mechanisms),
"totalRequests": totalRequests,
"totalSuccesses": totalSuccesses,
@@ -310,8 +310,9 @@ func (rm *RecoveryMetrics) GetAllMetrics() map[string]interface{} {
if totalRequests > 0 {
successRate := float64(totalSuccesses) / float64(totalRequests) * 100
allMetrics["summary"].(map[string]interface{})["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
summary["overallSuccessRate"] = fmt.Sprintf("%.2f%%", successRate)
}
allMetrics["summary"] = summary
return allMetrics
}
+3 -3
View File
@@ -223,7 +223,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelled(t *testing.T) {
wg.Wait()
if execErr == nil {
t.Error("Expected error when context is cancelled")
t.Error("Expected error when context is canceled")
}
}
@@ -240,7 +240,7 @@ func TestRetryExecutor_ExecuteWithContext_ContextCancelledBeforeStart(t *testing
})
if err == nil {
t.Error("Expected error when context is already cancelled")
t.Error("Expected error when context is already canceled")
}
}
@@ -282,7 +282,7 @@ func TestRetryExecutor_isRetryableError(t *testing.T) {
{name: "timeout", err: errors.New("TIMEOUT"), expected: true}, // case insensitive
{name: "EOF", err: errors.New("EOF"), expected: false},
{name: "random error", err: errors.New("something else"), expected: false},
{name: "context cancelled", err: context.Canceled, expected: false},
{name: "context canceled", err: context.Canceled, expected: false},
{name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false},
}
+23 -1
View File
@@ -155,12 +155,34 @@ func DetermineScheme(req *http.Request, forceHTTPS bool) string {
// It checks X-Forwarded-Host header first (for proxy scenarios),
// then falls back to req.Host.
func DetermineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
if host := sanitizeForwardedHost(req.Header.Get("X-Forwarded-Host")); host != "" {
return host
}
return req.Host
}
// sanitizeForwardedHost returns a single, well-formed host from a (possibly
// comma-separated) X-Forwarded-Host header, or "" if none is usable. It takes
// only the first value and rejects whitespace and control characters, so a
// crafted header cannot inject CRLF, smuggle a second host, or otherwise poison
// the redirect URLs built from the result.
func sanitizeForwardedHost(v string) string {
if v == "" {
return ""
}
if i := strings.IndexByte(v, ','); i >= 0 {
v = v[:i]
}
v = strings.TrimSpace(v)
if v == "" {
return ""
}
if strings.IndexFunc(v, func(r rune) bool { return r < 0x20 || r == 0x7f || r == ' ' }) >= 0 {
return ""
}
return v
}
// BuildFullURL constructs a URL from scheme, host, and path components.
// It handles absolute URLs (returning them as-is) and ensures paths have leading slashes.
func BuildFullURL(scheme, host, path string) string {
+135
View File
@@ -0,0 +1,135 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestIssue132_RefreshTokenHonorsUserIdentifierClaim reproduces and verifies
// the fix for issue #132: token refresh path hardcoded the "email" claim and
// ignored the configured userIdentifierClaim. Keycloak users without an email
// claim (using sub or another identifier) were being kicked out on refresh
// even though their initial login worked.
//
// The callback path (auth_flow.go) already honored userIdentifierClaim with
// "sub" fallback. The refresh path (token_manager.go) had drifted out of sync
// after PR #100 (commit a316a98).
func TestIssue132_RefreshTokenHonorsUserIdentifierClaim(t *testing.T) {
tests := []struct {
claims map[string]any
name string
userIdentifierClaim string
expectedIdentifier string
expectSuccess bool
}{
{
name: "sub claim configured, only sub present (Keycloak no-email case)",
userIdentifierClaim: "sub",
claims: map[string]any{
"sub": "user-uuid-keycloak-12345",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user-uuid-keycloak-12345",
},
{
name: "preferred_username configured, claim present",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "user-uuid-12345",
"preferred_username": "alice",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "alice",
},
{
name: "configured claim missing, falls back to sub",
userIdentifierClaim: "preferred_username",
claims: map[string]any{
"sub": "fallback-sub-id",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "fallback-sub-id",
},
{
name: "email default, email present (backward compatibility)",
userIdentifierClaim: "email",
claims: map[string]any{
"sub": "user-uuid-12345",
"email": "user@example.com",
"exp": float64(9999999999),
},
expectSuccess: true,
expectedIdentifier: "user@example.com",
},
{
name: "email default, no email and no sub - refresh fails",
userIdentifierClaim: "email",
claims: map[string]any{
"exp": float64(9999999999),
},
expectSuccess: false,
expectedIdentifier: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionManager, err := NewSessionManager(
"test-encryption-key-32-bytes-long!!",
false,
"",
"",
0,
NewLogger("error"),
)
if err != nil {
t.Fatalf("session manager: %v", err)
}
defer sessionManager.Shutdown()
capturedClaims := tt.claims
tOidc := &TraefikOidc{
logger: NewLogger("error"),
userIdentifierClaim: tt.userIdentifierClaim,
sessionManager: sessionManager,
tokenExchanger: &EnhancedMockTokenExchanger{
RefreshResponse: &TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
IDToken: "new-id-token-jwt",
ExpiresIn: 3600,
},
},
tokenVerifier: &EnhancedMockTokenVerifier{Err: nil},
extractClaimsFunc: func(token string) (map[string]any, error) {
return capturedClaims, nil
},
}
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("get session: %v", err)
}
defer session.returnToPoolSafely()
session.SetRefreshToken("initial-refresh-token")
refreshed := tOidc.refreshToken(rw, req, session)
if refreshed != tt.expectSuccess {
t.Fatalf("refreshToken() = %v, want %v", refreshed, tt.expectSuccess)
}
if got := session.GetUserIdentifier(); got != tt.expectedIdentifier {
t.Errorf("session.GetUserIdentifier() = %q, want %q", got, tt.expectedIdentifier)
}
})
}
}
+453
View File
@@ -0,0 +1,453 @@
package traefikoidc
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
// signGraphStyleAccessToken builds a JWT in Microsoft's Graph proprietary
// nonce-header form: bytes that get signed contain the SHA256 hash of the
// nonce, while the wire token ships the original nonce. A standard JWS
// verifier always rejects these with `crypto/rsa: verification error`, which
// is why Microsoft documents Graph access tokens as opaque to client apps:
//
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
// "you can't validate tokens for Microsoft Graph according to these rules
// due to their proprietary format"
func signGraphStyleAccessToken(t *testing.T, key *rsa.PrivateKey, kid, originalNonce string, claims map[string]any) string {
t.Helper()
wireHeader := map[string]any{
"alg": "RS256",
"kid": kid,
"typ": "JWT",
"nonce": originalNonce,
}
wireHeaderJSON, err := json.Marshal(wireHeader)
require.NoError(t, err)
hashed := sha256.Sum256([]byte(originalNonce))
signedHeader := map[string]any{
"alg": "RS256",
"kid": kid,
"typ": "JWT",
"nonce": fmt.Sprintf("%x", hashed),
}
signedHeaderJSON, err := json.Marshal(signedHeader)
require.NoError(t, err)
claimsJSON, err := json.Marshal(claims)
require.NoError(t, err)
wireHeaderB64 := base64.RawURLEncoding.EncodeToString(wireHeaderJSON)
signedHeaderB64 := base64.RawURLEncoding.EncodeToString(signedHeaderJSON)
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
signedInput := signedHeaderB64 + "." + claimsB64
hSign := sha256.Sum256([]byte(signedInput))
sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hSign[:])
require.NoError(t, err)
return wireHeaderB64 + "." + claimsB64 + "." + base64.RawURLEncoding.EncodeToString(sig)
}
// newAzureFollowupOIDC produces a TraefikOidc instance wired for an Azure
// AD tenant with a captured error log buffer. Used by the issue #134 followup
// tests to assert log behavior during validateAzureTokens flows.
func newAzureFollowupOIDC(t *testing.T, jwks *JWKSet) (*TraefikOidc, *bytes.Buffer) {
t.Helper()
tc := newTestCleanup(t)
errBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
tokenCache := tc.addTokenCache(NewTokenCache())
tokenBlacklist := tc.addCache(NewCache())
oidc := &TraefikOidc{
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
clientID: "test-client-id",
audience: "test-client-id",
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
limiter: rate.NewLimiter(rate.Every(time.Second), 100),
logger: logger,
httpClient: &http.Client{Timeout: 10 * time.Second},
jwkCache: &MockJWKCache{JWKS: jwks},
tokenCache: tokenCache,
tokenBlacklist: tokenBlacklist,
extractClaimsFunc: extractClaims,
}
oidc.tokenVerifier = oidc
oidc.jwtVerifier = oidc
require.True(t, oidc.isAzureProvider(), "fixture must be detected as Azure provider")
return oidc, errBuf
}
// authedSessionWithTokens returns a SessionData populated with the supplied
// access and ID tokens, marked authenticated and recently created. The
// SessionManager carries a real ChunkManager so that GetAccessToken /
// GetIDToken / GetRefreshToken behave like the production code path.
func authedSessionWithTokens(t *testing.T, accessToken, idToken string) *SessionData {
t.Helper()
chunkLogger := NewLogger("error")
chunkManager := NewChunkManager(chunkLogger)
t.Cleanup(chunkManager.Shutdown)
sd := CreateMockSessionData()
sd.manager = &SessionManager{
sessionMaxAge: 24 * time.Hour,
chunkManager: chunkManager,
logger: chunkLogger,
}
sd.mainSession = sessions.NewSession(nil, "main")
sd.mainSession.Values["authenticated"] = true
sd.mainSession.Values["created_at"] = time.Now().Unix()
sd.accessSession = sessions.NewSession(nil, "access")
sd.accessSession.Values["token"] = accessToken
sd.accessSession.Values["compressed"] = false
sd.idTokenSession = sessions.NewSession(nil, "id")
sd.idTokenSession.Values["token"] = idToken
sd.idTokenSession.Values["compressed"] = false
sd.refreshSession = sessions.NewSession(nil, "refresh")
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = false
return sd
}
// TestIssue134_Followup_GraphAccessTokenReproducesUsersError sanity-checks
// that our crafted Graph-style token reproduces the exact rsa error string
// quoted on the issue thread (dada-engineer 2026-05-08, friek 2026-05-11).
//
// Sanity test: must always pass, regardless of the issue #134 followup fix.
// It exists so a future contributor does not accidentally weaken the
// reproducer and assume the followup fix is no longer needed.
func TestIssue134_Followup_GraphAccessTokenReproducesUsersError(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-followup-kid"
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
parsedJWT, err := parseJWT(graphToken)
require.NoError(t, err)
pubKey := &rsaKey.PublicKey
alg, _ := parsedJWT.Header["alg"].(string)
verifyErr := verifySignatureWithKey(graphToken, pubKey, alg)
require.Error(t, verifyErr)
assert.Contains(t, verifyErr.Error(), "crypto/rsa: verification error",
"reproducer must emit the exact error string reported on issue #134")
}
// TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken is the
// failing-then-passing test for the followup fix.
//
// Symptom (before fix): validateAzureTokens calls verifyToken on every
// JWT-shaped access token. For Microsoft Graph access tokens (the default
// when no custom resource is registered), verification always fails with
// `crypto/rsa: verification error`, generating two error log lines per
// request:
//
// UNKNOWN token verification failed: signature verification failed:
// crypto/rsa: verification error
// DIAGNOSTIC: Signature verification failed for kid=<kid>, alg=RS256:
// crypto/rsa: verification error
//
// Microsoft's own documentation tells client apps not to validate Graph
// access tokens. The fix matches that guidance: when an Azure access token
// carries Microsoft's proprietary `nonce` JWT header, treat it as opaque
// (skip JWT verification, fall through to ID token validation).
func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-followup-kid"
jwk := JWK{
Kty: "RSA",
Use: "sig",
Alg: "RS256",
Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
now := time.Now()
exp := now.Add(time.Hour).Unix()
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-azure-graph", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": exp,
"iat": now.Unix(),
"sub": "user-azure-id",
"appid": "test-client-id",
"scp": "User.Read",
})
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"email": "user@example.com",
"nonce": "id-token-oidc-nonce",
"jti": "id-token-jti-followup",
})
require.NoError(t, err)
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, idToken)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
output := errBuf.String()
assert.NotContains(t, output, "crypto/rsa: verification error",
"validateAzureTokens must not log rsa verification error for Graph-style access tokens; got: %q", output)
assert.NotContains(t, output, "DIAGNOSTIC: Signature verification failed",
"DIAGNOSTIC line must not fire for Graph-style access tokens; got: %q", output)
assert.NotContains(t, output, "UNKNOWN token verification failed",
"UNKNOWN classification log must not fire for Graph-style access tokens; got: %q", output)
assert.True(t, authenticated, "session must remain authenticated via the ID token fallback")
assert.False(t, needsRefresh, "valid ID token must not signal a refresh need")
assert.False(t, expired, "valid ID token must not be reported as expired")
}
// TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection covers the
// classifier added by the followup fix. Pure-function unit test for the
// Microsoft proprietary marker we rely on (nonce in JWT header).
func TestIssue134_Followup_IsUnverifiableAzureAccessToken_Detection(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-detection-kid"
standardToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
})
require.NoError(t, err)
graphToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
oidc, _ := newAzureFollowupOIDC(t, &JWKSet{})
cases := []struct {
name string
token string
wantUnverified bool
}{
{name: "standard JWT without nonce header", token: standardToken, wantUnverified: false},
{name: "Microsoft proprietary token (nonce in header)", token: graphToken, wantUnverified: true},
{name: "garbage token treated as unverifiable", token: "not-a-jwt-at-all", wantUnverified: true},
{name: "empty token treated as unverifiable", token: "", wantUnverified: true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := oidc.isUnverifiableAzureAccessToken(tc.token)
assert.Equal(t, tc.wantUnverified, got)
})
}
}
// TestIssue134_Followup_StandardAzureAccessTokenStillVerifies guards against
// regression in the happy path: an access token issued for our own clientID
// (custom Azure-registered API) — no proprietary nonce header, signed normally
// — must still flow through the standard verification path and authenticate.
func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-standard-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
now := time.Now()
exp := now.Add(time.Hour).Unix()
// Custom-resource access token: aud points to the app, no nonce header.
accessToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"scp": "api.read",
"jti": "standard-access-jti",
})
require.NoError(t, err)
idToken, err := createTestJWT(rsaKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": exp,
"iat": now.Add(-2 * time.Minute).Unix(),
"nbf": now.Add(-2 * time.Minute).Unix(),
"sub": "user-azure-id",
"email": "user@example.com",
"nonce": "id-token-oidc-nonce",
"jti": "standard-id-jti",
})
require.NoError(t, err)
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, accessToken, idToken)
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
assert.False(t, needsRefresh)
assert.False(t, expired)
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error",
"standard Azure token must not produce signature errors")
}
// TestIssue134_Followup_GraphAccessTokenWithoutIDToken covers the edge where
// the session has only a Graph access token (no ID token). The classifier must
// preserve the existing "treat as opaque" semantics for backward compatibility:
// authenticated=true even when there is no ID token to verify.
func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-no-idt-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
graphAccessToken := signGraphStyleAccessToken(t, rsaKey, kid, "wire-only-nonce-no-idt", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "00000003-0000-0000-c000-000000000000",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "user-azure-id",
"scp": "User.Read",
})
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, graphAccessToken, "")
rs := (&requestState{}).captureSession(session)
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
assert.False(t, needsRefresh)
assert.False(t, expired)
assert.NotContains(t, errBuf.String(), "crypto/rsa: verification error")
}
// TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification proves
// the classifier is not a security regression. An attacker who forges a JWT
// with a `nonce` JWT header (Microsoft's proprietary marker) but a payload
// claiming `aud=our-clientID` should NOT gain authenticated status simply by
// triggering the "treat as opaque" branch.
//
// This is the confused-deputy guardrail Microsoft warns about
// (https://cwe.mitre.org/data/definitions/441.html): we treat the access token
// as opaque, which means we DO NOT authorize from it — authorization comes
// only from a separately verifiable ID token. An attacker without a valid ID
// token must not be authenticated.
func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
attackerKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-attack-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(rsaKey.E)).Bytes()),
}
jwks := &JWKSet{Keys: []JWK{jwk}}
// Forged: attacker uses their OWN key, sets aud = our clientID, plants a
// `nonce` header to trip the opaque-detection path.
forgedAccessToken := signGraphStyleAccessToken(t, attackerKey, kid, "attacker-nonce", map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "attacker",
"scp": "admin",
})
// Forged ID token signed with the attacker's key — must fail verification
// against the tenant JWKS.
forgedIDToken, err := createTestJWT(attackerKey, "RS256", kid, map[string]any{
"iss": "https://login.microsoftonline.com/tenant-id/v2.0",
"aud": "test-client-id",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Minute).Unix(),
"nbf": time.Now().Add(-2 * time.Minute).Unix(),
"sub": "attacker",
"email": "attacker@evil.example",
"nonce": "id-token-oidc-nonce",
"jti": "attacker-id-jti",
})
require.NoError(t, err)
oidc, _ := newAzureFollowupOIDC(t, jwks)
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
rs := (&requestState{}).captureSession(session)
authenticated, _, _ := oidc.validateAzureTokensRS(rs)
assert.False(t, authenticated,
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
}
+256
View File
@@ -0,0 +1,256 @@
package traefikoidc
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError reproduces and
// verifies the fix for issue #134.
//
// Symptom (before fix): with a Redis backend wired into UniversalCache,
// caching the parsed *parsedJWKS triggered:
//
// json: cannot unmarshal number 2251513...
// into Go value of type float64
//
// Root cause: under yaegi, json.Marshal of a struct exposes unexported
// fields with an X-prefixed name. parsedJWKS{ keys map[string]crypto.PublicKey }
// thus serialized the inner *rsa.PublicKey, whose modulus *big.Int marshals
// as a JSON number hundreds of digits long. On read, json.Unmarshal into
// interface{} parses numbers as float64, which cannot represent that range.
// The user saw the error log on every request even though auth still worked
// (fallback path rebuilt the keys in memory).
//
// Fix: route both *JWKSet and *parsedJWKS through SetLocal/GetLocal — the
// distributed backend never sees them.
func TestIssue134_AzureRSAJWKSDistributedCacheNoFloatError(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA",
Use: "sig",
Alg: "RS256",
Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
var fetchCount int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&fetchCount, 1)
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
errBuf := &bytes.Buffer{}
infoBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(infoBuf, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
ctx := context.Background()
pub1, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "first GetPublicKey should succeed")
require.NotNil(t, pub1)
gotRSA, ok := pub1.(*rsa.PublicKey)
require.True(t, ok, "returned key should be *rsa.PublicKey, got %T", pub1)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N), "modulus must survive intact")
assert.Equal(t, rsaKey.E, gotRSA.E, "exponent must survive intact")
pub2, err := jwkCache.GetPublicKey(ctx, server.URL, kid, http.DefaultClient)
require.NoError(t, err, "second GetPublicKey should succeed")
require.True(t, samePublicKey(pub1, pub2), "second call must return the same parsed key (cache hit)")
assert.Equal(t, int32(1), atomic.LoadInt32(&fetchCount),
"upstream JWKS endpoint must be hit exactly once; second call must be served from local cache")
errOutput := errBuf.String()
assert.NotContains(t, errOutput, "Failed to deserialize",
"deserialize error must not appear with the fix in place; got: %s", errOutput)
assert.NotContains(t, errOutput, "into Go value of type float64",
"float64 unmarshal error must not appear; got: %s", errOutput)
parsedKey := server.URL + parsedKeysSuffix
jwksKey := server.URL
for _, k := range []string{cache.prefixKey(parsedKey), cache.prefixKey(jwksKey)} {
fullKey := redisCfg.RedisPrefix + k
assert.False(t, mr.Exists(fullKey),
"key %q must not exist in Redis (local-only caching); got %v", fullKey, mr.Keys())
}
}
// TestIssue134_StalePoisonedRedisDataIgnored verifies that pre-existing bad
// data left in Redis under a JWK :parsed key from a prior buggy version is
// ignored: the local-only fix never reads that key, so no log spam, and the
// fallback path returns a real *rsa.PublicKey.
func TestIssue134_StalePoisonedRedisDataIgnored(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "issue134stale:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
const kid = "azure-test-kid"
jwk := JWK{
Kty: "RSA", Use: "sig", Alg: "RS256", Kid: kid,
N: base64.RawURLEncoding.EncodeToString(rsaKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big2bytes(rsaKey.E)),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(JWKSet{Keys: []JWK{jwk}})
}))
defer server.Close()
// Pre-poison Redis with the kind of payload the old buggy path would have
// produced (huge unquoted JSON number for the modulus). With the fix the
// JWKCache must not even read this key.
poisoned := []byte("\x01" + strings.Replace(
`{"Xkeys":{"azure-test-kid":{"N":NUMBER,"E":65537}}}`,
"NUMBER", rsaKey.N.String(), 1,
))
parsedRedisKey := redisCfg.RedisPrefix + "jwk:" + server.URL + parsedKeysSuffix
require.NoError(t, mr.Set(parsedRedisKey, string(poisoned)))
errBuf := &bytes.Buffer{}
logger := &Logger{
logError: log.New(errBuf, "", 0),
logInfo: log.New(io.Discard, "", 0),
logDebug: log.New(io.Discard, "", 0),
}
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeJWK,
MaxSize: 100,
Logger: logger,
}, backend)
defer cache.Close()
jwkCache := &JWKCache{cache: cache}
pub, err := jwkCache.GetPublicKey(context.Background(), server.URL, kid, http.DefaultClient)
require.NoError(t, err)
require.NotNil(t, pub)
gotRSA, ok := pub.(*rsa.PublicKey)
require.True(t, ok)
assert.Equal(t, 0, rsaKey.N.Cmp(gotRSA.N))
assert.NotContains(t, errBuf.String(), "Failed to deserialize",
"poisoned Redis entry must not be touched; got error log: %s", errBuf.String())
}
// TestIssue134_SetLocalGetLocalSkipBackend verifies the new SetLocal/GetLocal
// pair never reads or writes the configured backend.
func TestIssue134_SetLocalGetLocalSkipBackend(t *testing.T) {
mr, err := miniredis.Run()
require.NoError(t, err)
defer mr.Close()
redisCfg := backends.DefaultRedisConfig(mr.Addr())
redisCfg.RedisPrefix = "local:"
backend, err := backends.NewRedisBackend(redisCfg)
require.NoError(t, err)
defer backend.Close()
cache := NewUniversalCacheWithBackend(UniversalCacheConfig{
Type: CacheTypeGeneral,
MaxSize: 10,
Logger: GetSingletonNoOpLogger(),
}, backend)
defer cache.Close()
type unsafeShape struct {
hidden map[string]interface{}
}
val := &unsafeShape{hidden: map[string]interface{}{"k": 1}}
require.NoError(t, cache.SetLocal("local-key", val, 1*time.Hour))
got, found := cache.GetLocal("local-key")
require.True(t, found)
assert.Same(t, val, got, "GetLocal must return the exact pointer stored, no JSON round-trip")
for _, k := range mr.Keys() {
assert.NotContains(t, k, "local-key",
"SetLocal must not write to Redis; found key %q (all keys: %v)", k, mr.Keys())
}
cache.mu.Lock()
delete(cache.items, "local-key")
cache.lruList.Init()
cache.currentSize = 0
cache.currentMemory = 0
cache.mu.Unlock()
_, found = cache.GetLocal("local-key")
assert.False(t, found, "GetLocal must not fall back to backend after local cache cleared")
}
// big2bytes returns the big-endian byte slice for a positive int.
func big2bytes(e int) []byte {
if e <= 0 {
return []byte{}
}
var buf []byte
for e > 0 {
buf = append([]byte{byte(e & 0xff)}, buf...)
e >>= 8
}
return buf
}
// samePublicKey reports whether two crypto.PublicKey instances represent the
// same RSA key, used to confirm cache hits return identical reconstructed
// keys.
func samePublicKey(a, b interface{}) bool {
ar, ok1 := a.(*rsa.PublicKey)
br, ok2 := b.(*rsa.PublicKey)
if !ok1 || !ok2 {
return false
}
return ar.N.Cmp(br.N) == 0 && ar.E == br.E
}
+925
View File
@@ -0,0 +1,925 @@
package traefikoidc
// issue135_regression_test.go — regression tests for RFC 7523 private_key_jwt
// client authentication (issue #135).
//
// These tests guard:
// - Correct JWT construction and cryptographic signature for all supported
// algorithms (RS*/PS*/ES*).
// - Proper validation of alg/key type combinations and empty-kid rejection.
// - JTI uniqueness across concurrent calls.
// - PEM variant tolerance (PKCS#8, PKCS#1, SEC1).
// - Config.Validate() behavior for all private_key_jwt configuration paths.
// - buildClientAssertionSignerFromConfig: inline PEM, file-backed PEM, default alg.
// - Wire-up in exchangeTokens: assertion fields sent, client_secret absent.
// - Wire-up in RevokeTokenWithProvider: assertion fields sent, audience = tokenURL.
// - Back-compat: client_secret_post path unchanged when clientAssertion == nil.
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ── A. Signer unit tests ──────────────────────────────────────────────────────
// TestIssue135_SignerRSAFamily verifies that NewClientAssertionSigner + Sign
// produces a well-formed, cryptographically valid JWT for every RSA-family
// algorithm (RS256/RS384/RS512/PS256/PS384/PS512).
func TestIssue135_SignerRSAFamily(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
cases := []struct {
alg string
hashFn func([]byte) []byte
isPS bool
hash crypto.Hash
}{
{"RS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, false, crypto.SHA256},
{"RS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, false, crypto.SHA384},
{"RS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, false, crypto.SHA512},
{"PS256", func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, true, crypto.SHA256},
{"PS384", func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, true, crypto.SHA384},
{"PS512", func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, true, crypto.SHA512},
}
const (
audience = "https://example.com/token"
clientID = "client-abc"
kid = "kid-1"
)
for _, tc := range cases {
t.Run(tc.alg, func(t *testing.T) {
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
require.NoError(t, err)
jwtStr, err := signer.Sign(audience, clientID)
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3, "JWT must have three dot-separated parts")
// Decode and check header.
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, tc.alg, hdr["alg"])
assert.Equal(t, "JWT", hdr["typ"])
assert.Equal(t, kid, hdr["kid"])
// Decode and check claims.
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, clientID, clms["iss"])
assert.Equal(t, clientID, clms["sub"])
assert.Equal(t, audience, clms["aud"])
iat, ok := clms["iat"].(float64)
require.True(t, ok, "iat must be numeric")
exp, ok := clms["exp"].(float64)
require.True(t, ok, "exp must be numeric")
assert.InDelta(t, 60, exp-iat, 2, "exp-iat must equal ~60s")
now := float64(time.Now().Unix())
assert.True(t, iat <= now+2 && iat >= now-5, "iat must be current time ±5s")
jti, ok := clms["jti"].(string)
require.True(t, ok, "jti must be a string")
assert.Len(t, jti, 32, "jti must be 32-char hex (16 bytes → hex)")
// Verify cryptographic signature.
sigInput := parts[0] + "." + parts[1]
digest := tc.hashFn([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
pub := &rsaKey.PublicKey
if tc.isPS {
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: tc.hash}
assert.NoError(t, rsa.VerifyPSS(pub, tc.hash, digest, sigBytes, opts),
"PSS signature verification failed for %s", tc.alg)
} else {
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tc.hash, digest, sigBytes),
"PKCS1v15 signature verification failed for %s", tc.alg)
}
})
}
}
// TestIssue135_SignerECDSAFamily verifies correct JWT production for all
// ECDSA algorithms (ES256/ES384/ES512) including that the signature is the
// raw r||s encoding (not ASN.1 DER) and is verifiable with the matching key.
func TestIssue135_SignerECDSAFamily(t *testing.T) {
cases := []struct {
alg string
curve elliptic.Curve
hashFn func([]byte) []byte
hash crypto.Hash
}{
{"ES256", elliptic.P256(), func(b []byte) []byte { h := sha256.Sum256(b); return h[:] }, crypto.SHA256},
{"ES384", elliptic.P384(), func(b []byte) []byte { h := sha512.Sum384(b); return h[:] }, crypto.SHA384},
{"ES512", elliptic.P521(), func(b []byte) []byte { h := sha512.Sum512(b); return h[:] }, crypto.SHA512},
}
const (
audience = "https://idp.example.com/token"
clientID = "ec-client"
kid = "ec-kid"
)
for _, tc := range cases {
t.Run(tc.alg, func(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(tc.curve, rand.Reader)
require.NoError(t, err)
pemBytes := encodeECPKCS8(t, ecKey)
signer, err := NewClientAssertionSigner(pemBytes, tc.alg, kid)
require.NoError(t, err)
jwtStr, err := signer.Sign(audience, clientID)
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
byteLen := (tc.curve.Params().BitSize + 7) / 8
assert.Len(t, sigBytes, 2*byteLen,
"ECDSA signature must be raw r||s (2×%d bytes for %s)", byteLen, tc.alg)
r := new(big.Int).SetBytes(sigBytes[:byteLen])
s := new(big.Int).SetBytes(sigBytes[byteLen:])
sigInput := parts[0] + "." + parts[1]
digest := tc.hashFn([]byte(sigInput))
ok := ecdsa.Verify(&ecKey.PublicKey, digest, r, s)
assert.True(t, ok, "ECDSA signature verification failed for %s", tc.alg)
})
}
}
// TestIssue135_SignerRejectsAlgKeyMismatch verifies that the signer constructor
// rejects type mismatches between key type and algorithm, unknown algorithms,
// and an empty kid.
func TestIssue135_SignerRejectsAlgKeyMismatch(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
rsaPEM := encodeRSAPKCS8(t, rsaKey)
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
ecPEM := encodeECPKCS8(t, ecKey)
cases := []struct {
name string
pemBytes []byte
alg string
kid string
wantErr string
}{
{
name: "RSA key with ES256",
pemBytes: rsaPEM,
alg: "ES256",
kid: "k1",
wantErr: "EC key",
},
{
name: "EC key with RS256",
pemBytes: ecPEM,
alg: "RS256",
kid: "k1",
wantErr: "RSA key",
},
{
name: "unknown alg HS256",
pemBytes: rsaPEM,
alg: "HS256",
kid: "k1",
wantErr: "unsupported",
},
{
name: "empty kid",
pemBytes: rsaPEM,
alg: "RS256",
kid: "",
wantErr: "kid must not be empty",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
_, err := NewClientAssertionSigner(tc.pemBytes, tc.alg, tc.kid)
require.Error(t, err)
assert.Contains(t, strings.ToLower(err.Error()), strings.ToLower(tc.wantErr),
"error should mention %q", tc.wantErr)
})
}
}
// TestIssue135_SignerJTIUniqueness signs 50 assertions with the same signer
// and asserts all jti values are distinct. Guards against broken entropy reuse.
func TestIssue135_SignerJTIUniqueness(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "jti-kid")
require.NoError(t, err)
seen := make(map[string]bool, 50)
for i := range 50 {
jwtStr, err := signer.Sign("https://example.com/token", "client-x")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
clms := decodeJSONPart(t, parts[1])
jti, ok := clms["jti"].(string)
require.True(t, ok)
assert.False(t, seen[jti], "jti %q was reused at iteration %d", jti, i)
seen[jti] = true
}
}
// TestIssue135_SignerPEMVariants confirms that all PEM block types understood
// by NewClientAssertionSigner are parsed correctly: PKCS#8 ("PRIVATE KEY"),
// PKCS#1 ("RSA PRIVATE KEY"), and SEC1 ("EC PRIVATE KEY").
func TestIssue135_SignerPEMVariants(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
t.Run("RSA PKCS8", func(t *testing.T) {
pemBytes := encodeRSAPKCS8(t, rsaKey)
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
require.NoError(t, err)
assertValidRSAJWT(t, rsaKey, signer, "RS256")
})
t.Run("RSA PKCS1", func(t *testing.T) {
der := x509.MarshalPKCS1PrivateKey(rsaKey)
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der})
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "k1")
require.NoError(t, err)
assertValidRSAJWT(t, rsaKey, signer, "RS256")
})
t.Run("EC PKCS8", func(t *testing.T) {
pemBytes := encodeECPKCS8(t, ecKey)
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "cid")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
})
t.Run("EC SEC1", func(t *testing.T) {
der, err := x509.MarshalECPrivateKey(ecKey)
require.NoError(t, err)
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der})
signer, err := NewClientAssertionSigner(pemBytes, "ES256", "k1")
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "cid")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
})
}
// ── B. Config validation ──────────────────────────────────────────────────────
// TestIssue135_ConfigValidation table-drives Config.Validate() for every
// client-authentication-related validation branch.
func TestIssue135_ConfigValidation(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
validPEM := string(encodeRSAPKCS8(t, rsaKey))
// baseConfig returns the minimum valid config, modified per test case.
base := func() *Config {
return &Config{
ProviderURL: "https://idp.example.com",
CallbackURL: "/cb",
ClientID: "cid",
ClientSecret: "secret",
SessionEncryptionKey: "01234567890123456789012345678901", // 32 chars
RateLimit: 100,
}
}
cases := []struct {
name string
mutate func(*Config)
wantErr string // empty = expect nil error
}{
{
name: "default empty method + secret ok",
mutate: func(c *Config) { /* nothing extra */ },
wantErr: "",
},
{
name: "explicit client_secret_post + secret ok",
mutate: func(c *Config) {
c.ClientAuthMethod = "client_secret_post"
},
wantErr: "",
},
{
name: "private_key_jwt inline key + kid ok",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyID = "k1"
},
wantErr: "",
},
{
name: "private_key_jwt no key at all",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionKeyID = "k1"
},
wantErr: "clientAssertionPrivateKey",
},
{
name: "private_key_jwt both inline and path",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyPath = "/tmp/key.pem"
c.ClientAssertionKeyID = "k1"
},
wantErr: "only one of",
},
{
name: "private_key_jwt key but no kid",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
},
wantErr: "clientAssertionKeyID",
},
{
name: "private_key_jwt unsupported alg HS256",
mutate: func(c *Config) {
c.ClientAuthMethod = "private_key_jwt"
c.ClientSecret = ""
c.ClientAssertionPrivateKey = validPEM
c.ClientAssertionKeyID = "k1"
c.ClientAssertionAlg = "HS256"
},
wantErr: "is not supported",
},
{
name: "unknown client auth method",
mutate: func(c *Config) {
c.ClientAuthMethod = "weird"
},
wantErr: "is not supported",
},
{
name: "client_secret_post with no secret",
mutate: func(c *Config) {
c.ClientAuthMethod = "client_secret_post"
c.ClientSecret = ""
},
wantErr: "clientSecret is required",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
cfg := base()
tc.mutate(cfg)
err := cfg.Validate()
if tc.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErr,
"error must mention %q", tc.wantErr)
}
})
}
}
// TestIssue135_ConfigKeyPathLoadsFile verifies that buildClientAssertionSignerFromConfig
// reads the PEM key from disk when ClientAssertionKeyPath is set.
func TestIssue135_ConfigKeyPathLoadsFile(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
dir := t.TempDir()
keyFile := dir + "/private.pem"
require.NoError(t, os.WriteFile(keyFile, pemBytes, 0o600))
cfg := &Config{
ClientAuthMethod: "private_key_jwt",
ClientAssertionKeyPath: keyFile,
ClientAssertionKeyID: "file-kid",
ClientAssertionAlg: "RS256",
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err, "should load signer from key file")
require.NotNil(t, signer)
// Confirm signer produces a valid JWT.
jwtStr, err := signer.Sign("https://example.com/token", "client-from-file")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3, "should produce a 3-part JWT")
}
// ── C. Wire-up — exchangeTokens ───────────────────────────────────────────────
// TestIssue135_AuthCodeExchangeUsesAssertion confirms that exchangeTokens sends
// client_assertion + client_assertion_type instead of client_secret when a
// ClientAssertionSigner is configured, and that the assertion JWT is valid.
func TestIssue135_AuthCodeExchangeUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
var capturedBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body := make([]byte, r.ContentLength)
_, _ = r.Body.Read(body)
capturedBody = body
w.Header().Set("Content-Type", "application/json")
// Return a minimal token response so exchangeTokens doesn't error.
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at",
IDToken: "it",
RefreshToken: "rt",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "wire-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: "wire-client",
tokenHTTPClient: server.Client(),
clientAssertion: signer,
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err = oidc.exchangeTokens(context.Background(), "authorization_code", "code-x", "https://app/cb", "")
require.NoError(t, err)
form, err := url.ParseQuery(string(capturedBody))
require.NoError(t, err)
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
form.Get("client_assertion_type"), "client_assertion_type must be set")
assertionJWT := form.Get("client_assertion")
assert.NotEmpty(t, assertionJWT, "client_assertion must be present")
assert.Empty(t, form.Get("client_secret"), "client_secret must not be sent when using assertion")
assert.Equal(t, "wire-client", form.Get("client_id"))
assert.Equal(t, "code-x", form.Get("code"))
assert.Equal(t, "authorization_code", form.Get("grant_type"))
// Verify assertion JWT: header, claims, signature.
parts := strings.Split(assertionJWT, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "RS256", hdr["alg"])
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, "wire-client", clms["iss"])
assert.Equal(t, "wire-client", clms["sub"])
assert.Equal(t, server.URL, clms["aud"],
"audience must be the tokenURL (RFC 7523 §3)")
// Verify signature with RSA public key.
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
}
// TestIssue135_RefreshTokenUsesAssertion verifies that the refresh_token grant
// type also sends client_assertion and the correct form fields.
func TestIssue135_RefreshTokenUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-at",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rt-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: "rt-client",
tokenHTTPClient: server.Client(),
clientAssertion: signer,
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err = oidc.exchangeTokens(context.Background(), "refresh_token", "rt-y", "", "")
require.NoError(t, err)
assert.Equal(t, "refresh_token", capturedForm.Get("grant_type"))
assert.Equal(t, "rt-y", capturedForm.Get("refresh_token"))
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
capturedForm.Get("client_assertion_type"))
assert.NotEmpty(t, capturedForm.Get("client_assertion"))
assert.Empty(t, capturedForm.Get("client_secret"))
}
// TestIssue135_BackcompatClientSecretPath confirms that exchangeTokens sends
// client_secret and does NOT send client_assertion when clientAssertion is nil.
func TestIssue135_BackcompatClientSecretPath(t *testing.T) {
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at",
TokenType: "Bearer",
ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "legacy-client",
clientSecret: "legacy-secret",
tokenHTTPClient: server.Client(),
clientAssertion: nil, // back-compat path
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bc", "https://app/cb", "")
require.NoError(t, err)
assert.Equal(t, "legacy-secret", capturedForm.Get("client_secret"),
"client_secret must be sent on the classic path")
assert.Empty(t, capturedForm.Get("client_assertion"),
"client_assertion must NOT be present on the classic path")
assert.Empty(t, capturedForm.Get("client_assertion_type"),
"client_assertion_type must NOT be present on the classic path")
}
// TestIssue135_ClientSecretBasicAuth verifies that when clientAuthMethod is
// "client_secret_basic", exchangeTokens sends an HTTP Basic Authorization
// header carrying url-encoded client_id:client_secret per RFC 6749 §2.3.1,
// and that neither client_id nor client_secret appears in the form body.
func TestIssue135_ClientSecretBasicAuth(t *testing.T) {
var capturedAuth string
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "at-basic", TokenType: "Bearer", ExpiresIn: 3600,
})
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "basic-client",
clientSecret: "basic-secret",
clientAuthMethod: "client_secret_basic",
tokenHTTPClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "code-bb", "https://app/cb", "")
require.NoError(t, err)
require.True(t, strings.HasPrefix(capturedAuth, "Basic "),
"Authorization header must start with 'Basic ', got %q", capturedAuth)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err, "Authorization payload must be valid base64")
user, pass, ok := strings.Cut(string(raw), ":")
require.True(t, ok, "Authorization payload must contain a single ':' separator")
assert.Equal(t, "basic-client", user, "client_id should round-trip through QueryEscape")
assert.Equal(t, "basic-secret", pass, "client_secret should round-trip through QueryEscape")
assert.Empty(t, capturedForm.Get("client_id"),
"client_id must NOT be in the body when using client_secret_basic")
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must NOT be in the body when using client_secret_basic")
assert.Empty(t, capturedForm.Get("client_assertion"),
"client_assertion must NOT be present on the basic-auth path")
}
// TestIssue135_ClientSecretBasicURLEncodesReservedChars verifies that
// credentials containing reserved characters (`:`, `+`, `/`, etc.) are
// form-urlencoded before base64 per RFC 6749 §2.3.1, so the receiving
// authorization server can decode them deterministically.
func TestIssue135_ClientSecretBasicURLEncodesReservedChars(t *testing.T) {
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(TokenResponse{AccessToken: "at", TokenType: "Bearer", ExpiresIn: 3600})
}))
defer server.Close()
const (
clientID = "weird:id+1"
clientSecret = "p@ss/word=&" //nolint:gosec // test fixture
)
oidc := &TraefikOidc{
clientID: clientID,
clientSecret: clientSecret,
clientAuthMethod: "client_secret_basic",
tokenHTTPClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = server.URL
_, err := oidc.exchangeTokens(context.Background(), "authorization_code", "c", "https://app/cb", "")
require.NoError(t, err)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err)
wantUser := url.QueryEscape(clientID)
wantPass := url.QueryEscape(clientSecret)
assert.Equal(t, wantUser+":"+wantPass, string(raw),
"both halves must be form-urlencoded before the base64 step")
}
// TestIssue135_ClientSecretBasicRevocation verifies that the revocation path
// honors client_secret_basic identically to the token path.
func TestIssue135_ClientSecretBasicRevocation(t *testing.T) {
var capturedAuth string
var capturedForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
oidc := &TraefikOidc{
clientID: "rev-basic",
clientSecret: "rev-secret",
clientAuthMethod: "client_secret_basic",
httpClient: server.Client(),
logger: GetSingletonNoOpLogger(),
}
oidc.tokenURL = "https://idp.example.com/token"
oidc.revocationURL = server.URL
require.NoError(t, oidc.RevokeTokenWithProvider("opaque-tok", "access_token"))
require.True(t, strings.HasPrefix(capturedAuth, "Basic "), "got %q", capturedAuth)
raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(capturedAuth, "Basic "))
require.NoError(t, err)
assert.Equal(t, "rev-basic:rev-secret", string(raw))
assert.Equal(t, "opaque-tok", capturedForm.Get("token"))
assert.Equal(t, "access_token", capturedForm.Get("token_type_hint"))
assert.Empty(t, capturedForm.Get("client_id"),
"client_id must NOT be in body on Basic-auth revocation")
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must NOT be in body on Basic-auth revocation")
}
// ── D. Wire-up — RevokeTokenWithProvider ────────────────────────────────────
// TestIssue135_RevocationUsesAssertion verifies that RevokeTokenWithProvider
// sends client_assertion (not client_secret), and that the assertion's audience
// is the tokenURL, not the revocationURL (per RFC 7523 §3).
func TestIssue135_RevocationUsesAssertion(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
const (
tokenEndpoint = "https://idp.example.com/token" // audience for assertion
clientIDVal = "revoke-client"
)
var capturedForm url.Values
// Revocation endpoint — deliberate separate URL to confirm audience != revocationURL.
revokeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, r.ParseForm())
capturedForm = r.Form
w.WriteHeader(http.StatusOK)
}))
defer revokeServer.Close()
signer, err := NewClientAssertionSigner(pemBytes, "RS256", "rev-kid")
require.NoError(t, err)
oidc := &TraefikOidc{
clientID: clientIDVal,
clientAssertion: signer,
httpClient: revokeServer.Client(),
logger: GetSingletonNoOpLogger(),
}
// tokenURL drives assertion audience; revocationURL is where the POST goes.
oidc.tokenURL = tokenEndpoint
oidc.revocationURL = revokeServer.URL
err = oidc.RevokeTokenWithProvider("some-token", "refresh_token")
require.NoError(t, err)
assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
capturedForm.Get("client_assertion_type"))
assertionJWT := capturedForm.Get("client_assertion")
assert.NotEmpty(t, assertionJWT)
assert.Empty(t, capturedForm.Get("client_secret"),
"client_secret must not appear in revocation request with assertion")
// Verify the assertion audience is tokenURL (not revocationURL).
parts := strings.Split(assertionJWT, ".")
require.Len(t, parts, 3)
clms := decodeJSONPart(t, parts[1])
assert.Equal(t, tokenEndpoint, clms["aud"],
"assertion audience must be tokenURL, not revocationURL")
// Sanity-check cryptographic validity.
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
assert.NoError(t, rsa.VerifyPKCS1v15(&rsaKey.PublicKey, crypto.SHA256, digest, sigBytes))
}
// ── E. End-to-end via buildClientAssertionSignerFromConfig ───────────────────
// TestIssue135_BuildSignerFromInlineConfig confirms that the full config→signer
// pipeline works for an ES256 key specified inline in the Config struct.
func TestIssue135_BuildSignerFromInlineConfig(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pemBytes := encodeECPKCS8(t, ecKey)
cfg := &Config{
ClientAuthMethod: "private_key_jwt",
ClientAssertionPrivateKey: string(pemBytes),
ClientAssertionKeyID: "inline-ec-kid",
ClientAssertionAlg: "ES256",
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err)
require.NotNil(t, signer)
jwtStr, err := signer.Sign("https://example.com/token", "inline-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "ES256", hdr["alg"])
assert.Equal(t, "inline-ec-kid", hdr["kid"])
// Verify the EC signature.
byteLen := (elliptic.P256().Params().BitSize + 7) / 8
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
require.Len(t, sigBytes, 2*byteLen)
r := new(big.Int).SetBytes(sigBytes[:byteLen])
s := new(big.Int).SetBytes(sigBytes[byteLen:])
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
assert.True(t, ecdsa.Verify(&ecKey.PublicKey, digest, r, s))
}
// TestIssue135_BuildSignerDefaultsToRS256 verifies that an empty
// ClientAssertionAlg defaults to RS256.
func TestIssue135_BuildSignerDefaultsToRS256(t *testing.T) {
rsaKey := genRSAKey(t, 2048)
pemBytes := encodeRSAPKCS8(t, rsaKey)
cfg := &Config{
ClientAssertionPrivateKey: string(pemBytes),
ClientAssertionKeyID: "default-alg-kid",
ClientAssertionAlg: "", // intentionally empty
}
signer, err := buildClientAssertionSignerFromConfig(cfg)
require.NoError(t, err)
jwtStr, err := signer.Sign("https://example.com/token", "default-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, "RS256", hdr["alg"], "empty alg must default to RS256")
}
// ── Helpers ───────────────────────────────────────────────────────────────────
// genRSAKey generates an RSA key of the given bit size, failing the test on error.
func genRSAKey(t *testing.T, bits int) *rsa.PrivateKey {
t.Helper()
k, err := rsa.GenerateKey(rand.Reader, bits)
require.NoError(t, err)
return k
}
// encodeRSAPKCS8 marshals an RSA key as PKCS#8 PEM ("PRIVATE KEY").
func encodeRSAPKCS8(t *testing.T, key *rsa.PrivateKey) []byte {
t.Helper()
der, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
}
// encodeECPKCS8 marshals an EC key as PKCS#8 PEM ("PRIVATE KEY").
func encodeECPKCS8(t *testing.T, key *ecdsa.PrivateKey) []byte {
t.Helper()
der, err := x509.MarshalPKCS8PrivateKey(key)
require.NoError(t, err)
return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
}
// decodeJSONPart base64url-decodes a JWT part and parses it as a JSON object.
func decodeJSONPart(t *testing.T, b64url string) map[string]any {
t.Helper()
raw, err := base64.RawURLEncoding.DecodeString(b64url)
require.NoError(t, err, "base64url decode of JWT part failed")
var m map[string]any
require.NoError(t, json.Unmarshal(raw, &m), "JSON unmarshal of JWT part failed")
return m
}
// sha256SumBytes returns the SHA-256 digest of b as a byte slice.
func sha256SumBytes(b []byte) []byte {
h := sha256.Sum256(b)
return h[:]
}
// assertValidRSAJWT signs a JWT with signer and verifies the RS256 signature
// against the given RSA public key. Used by PEM variant tests.
func assertValidRSAJWT(t *testing.T, key *rsa.PrivateKey, signer *ClientAssertionSigner, alg string) {
t.Helper()
jwtStr, err := signer.Sign("https://example.com/token", "pem-client")
require.NoError(t, err)
parts := strings.Split(jwtStr, ".")
require.Len(t, parts, 3)
hdr := decodeJSONPart(t, parts[0])
assert.Equal(t, alg, hdr["alg"])
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
require.NoError(t, err)
sigInput := parts[0] + "." + parts[1]
digest := sha256SumBytes([]byte(sigInput))
assert.NoError(t, rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, digest, sigBytes))
}
+7 -6
View File
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
// Test 3: Rate limiting
t.Run("RateLimiting", func(t *testing.T) {
// Reset circuit breaker to closed state for this test
coordinator.circuitBreaker.mutex.Lock()
// Reset circuit breaker to closed state for this test. All fields are
// atomic so we don't need any mutex.
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
coordinator.circuitBreaker.mutex.Unlock()
// Temporarily increase circuit breaker threshold to not interfere
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
time.Sleep(config.CleanupInterval * 3)
// Old sessions should be cleaned up
coordinator.attemptsMutex.RLock()
count := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
count := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
count++
return true
})
// Should have fewer sessions after cleanup
if count > 10 {
+159 -21
View File
@@ -2,6 +2,7 @@ package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
@@ -18,6 +19,18 @@ import (
"time"
)
// parsedKeysSuffix marks the parallel UniversalCache entry that stores
// pre-parsed public keys for a given JWKS URL.
const parsedKeysSuffix = ":parsed"
// parsedJWKS holds keys decoded from a JWKSet, indexed by kid. Storing the
// already-parsed crypto.PublicKey avoids re-running the DER/PEM round trip
// on every JWT verification — a costly operation under the yaegi interpreter
// that hosts Traefik plugins.
type parsedJWKS struct {
keys map[string]crypto.PublicKey
}
// JWK represents a JSON Web Key as defined in RFC 7517.
// It can represent different key types including RSA, EC, and symmetric keys.
type JWK struct {
@@ -40,15 +53,32 @@ type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache provides thread-safe caching of JWKS using UniversalCache
// JWKCache provides thread-safe caching of JWKS using UniversalCache.
//
// inflightFetches deduplicates concurrent fetches for the same JWKS URL.
// It replaces a global sync.RWMutex that was previously held for the entire
// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief
// network blip) every concurrent request piled up on that single Lock(), and
// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch
// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP
// fetch regardless of how many requests are waiting.
type JWKCache struct {
cache *UniversalCache
mutex sync.RWMutex
cache *UniversalCache
inflightFetches sync.Map // map[jwksURL string]*jwksFetch
}
// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch
// completes; jwks and err carry the result (one of them is set, never both).
type jwksFetch struct {
done chan struct{}
jwks *JWKSet
err error
}
// JWKCacheInterface defines the contract for JWK caching implementations.
type JWKCacheInterface interface {
GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error)
GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error)
Cleanup()
Close()
}
@@ -62,38 +92,146 @@ func NewJWKCache() *JWKCache {
}
// GetJWKS retrieves JWKS from cache or fetches from the remote URL if not cached.
//
// The entry is stored locally only via SetLocal/GetLocal. Going through a
// distributed backend defeats the cache: JSON round-tripping turns *JWKSet
// into map[string]interface{}, the type assertion below fails, and every
// request refetches from the upstream. JWK rotation is rare and a per-replica
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
// Check cache first
if cachedValue, found := c.cache.Get(jwksURL); found {
// Fast path: cache hit.
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
return jwks, nil
}
}
c.mutex.Lock()
defer c.mutex.Unlock()
// Singleflight: dedupe concurrent fetches per URL key. The first arrival
// performs the HTTP fetch; any later arrival for the same URL waits on
// its done channel and shares the result. No global lock is held during
// the fetch.
candidate := &jwksFetch{done: make(chan struct{})}
if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded {
f, _ := existing.(*jwksFetch)
select {
case <-f.done:
return f.jwks, f.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Double-check after acquiring lock
if cachedValue, found := c.cache.Get(jwksURL); found {
// We're the leader. Make absolutely sure the result fields and the
// in-flight map entry are cleaned up before any waiter unblocks.
defer func() {
c.inflightFetches.Delete(jwksURL)
close(candidate.done)
}()
// Re-check the cache in case a concurrent fetch completed between our
// initial miss and our LoadOrStore win.
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
if jwks, ok := cachedValue.(*JWKSet); ok {
candidate.jwks = jwks
return jwks, nil
}
}
// Fetch from URL
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
if err != nil {
candidate.err = err
return nil, err
}
if len(jwks.Keys) == 0 {
candidate.err = fmt.Errorf("JWKS response contains no keys")
return nil, candidate.err
}
// Cache for 1 hour.
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
candidate.jwks = jwks
return jwks, nil
}
// GetPublicKey returns the parsed public key for a given kid, fetching and
// caching the JWKS plus its derived parsedJWKS on miss. The parsed entry is
// stored alongside the raw JWKSet under a sibling cache key with the same
// 1-hour TTL, so both invalidate together when the upstream JWKS rotates.
//
// parsedJWKS is stored locally only (SetLocal/GetLocal). Its values are
// crypto.PublicKey interfaces wrapping *rsa.PublicKey/*ecdsa.PublicKey,
// which contain *big.Int that marshals to a hundreds-digit JSON number.
// On a distributed backend round-trip, json.Unmarshal into interface{} would
// try to fit that into float64 and fail with UnmarshalTypeError. Under yaegi
// the unexported parsedJWKS.keys field is exposed via an X-prefixed name on
// Marshal, leaking the modulus into the cached payload (issue #134).
func (c *JWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
parsedKey := jwksURL + parsedKeysSuffix
if v, found := c.cache.GetLocal(parsedKey); found {
if pj, ok := v.(*parsedJWKS); ok {
if k, ok := pj.keys[kid]; ok {
return k, nil
}
}
}
jwks, err := c.GetJWKS(ctx, jwksURL, httpClient)
if err != nil {
return nil, err
}
if len(jwks.Keys) == 0 {
return nil, fmt.Errorf("JWKS response contains no keys")
pj := buildParsedJWKS(jwks)
_ = c.cache.SetLocal(parsedKey, pj, 1*time.Hour) // Safe to ignore: cache failures are non-critical
if k, ok := pj.keys[kid]; ok {
return k, nil
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
// Cache for 1 hour
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
return jwks, nil
// buildParsedJWKS pre-parses every JWK in the set into the matching
// crypto.PublicKey, indexed by kid. Errors on individual keys are skipped so
// a single bad key does not block the rest of the keyset.
func buildParsedJWKS(jwks *JWKSet) *parsedJWKS {
out := make(map[string]crypto.PublicKey, len(jwks.Keys))
for i := range jwks.Keys {
k := &jwks.Keys[i]
if k.Kid == "" {
continue
}
// Skip keys that are not intended for signature verification.
if k.Use != "" && k.Use != "sig" {
continue
}
if len(k.KeyOps) > 0 {
hasVerify := false
for _, op := range k.KeyOps {
if op == "verify" {
hasVerify = true
break
}
}
if !hasVerify {
continue
}
}
var pub crypto.PublicKey
var err error
switch k.Kty {
case "RSA":
pub, err = k.ToRSAPublicKey()
case "EC":
pub, err = k.ToECDSAPublicKey()
default:
continue
}
if err != nil {
continue
}
out[k.Kid] = pub
}
return &parsedJWKS{keys: out}
}
// Cleanup is a no-op as cleanup is handled by UniversalCache
@@ -120,11 +258,11 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
body, _ := io.ReadAll(io.LimitReader(resp.Body, 10*1024)) // Safe to ignore: reading error body for diagnostics
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
}
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("error reading JWKS response: %w", err)
}
@@ -213,9 +351,9 @@ func (jwk *JWK) ToECDSAPublicKey() (*ecdsa.PublicKey, error) {
// GetKey finds a key by its ID (kid) in the JWKSet.
// Returns nil if no key with the given ID is found.
func (jwks *JWKSet) GetKey(kid string) *JWK {
for _, key := range jwks.Keys {
if key.Kid == kid {
return &key
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
return &jwks.Keys[i]
}
}
return nil
+16 -9
View File
@@ -120,7 +120,7 @@ func getReplayCacheStats() (size int, maxSize int) {
// Parameters:
// - ctx: Parent context for cancellation
// - logger: Logger for debug output (can be nil)
func startReplayCacheCleanup(ctx context.Context, logger *Logger) {
func startReplayCacheCleanup(_ context.Context, logger *Logger) {
registry := GetGlobalTaskRegistry()
// Define the cleanup task function
@@ -528,6 +528,21 @@ func verifyNotBefore(notBefore float64) error {
// - An error if the key parsing fails, the algorithm is unsupported,
// or the signature verification fails
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
return verifySignatureWithKey(tokenString, pubKey, alg)
}
// verifySignatureWithKey verifies a JWT signature using an already-parsed
// public key, skipping the PEM-encode/decode round trip that verifySignature
// performs. This is the hot path used by VerifyJWTSignatureAndClaims.
func verifySignatureWithKey(tokenString string, pubKey crypto.PublicKey, alg string) error {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
@@ -537,14 +552,6 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
+505
View File
@@ -0,0 +1,505 @@
// Package traefikoidc provides OIDC authentication middleware for Traefik.
// This file implements OIDC Backchannel Logout (OpenID Connect Back-Channel Logout 1.0)
// and Front-Channel Logout (OpenID Connect Front-Channel Logout 1.0) functionality.
package traefikoidc
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
const (
// logoutTokenType is the expected typ claim for logout tokens
// #nosec G101 -- This is a JWT type claim value from OIDC spec, not a credential
logoutTokenType = "logout+jwt"
// sessionInvalidationTTL is how long to remember invalidated sessions
// Should be at least as long as your session max age
sessionInvalidationTTL = 25 * time.Hour
)
// LogoutTokenClaims represents the claims in an OIDC logout token
// as defined in OpenID Connect Back-Channel Logout 1.0
type LogoutTokenClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub,omitempty"`
Audience interface{} `json:"aud"` // Can be string or []string
IssuedAt int64 `json:"iat"`
JTI string `json:"jti"`
Events map[string]interface{} `json:"events"`
SessionID string `json:"sid,omitempty"`
Nonce string `json:"nonce,omitempty"` // Must NOT be present
}
// handleBackchannelLogout processes OIDC Backchannel Logout requests.
// It accepts POST requests with a logout_token parameter containing a JWT
// that identifies which session(s) to terminate.
//
// According to OpenID Connect Back-Channel Logout 1.0:
// - The logout_token is a JWT signed by the IdP
// - It contains either a 'sid' (session ID) or 'sub' (subject) claim to identify the session
// - The RP must validate the token and invalidate the matching session(s)
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing the logout_token
func (t *TraefikOidc) handleBackchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing backchannel logout request")
// Backchannel logout must be POST
if req.Method != http.MethodPost {
t.logger.Errorf("Backchannel logout: invalid method %s, expected POST", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse form data to get logout_token
if err := req.ParseForm(); err != nil {
t.logger.Errorf("Backchannel logout: failed to parse form: %v", err)
http.Error(rw, "Bad request", http.StatusBadRequest)
return
}
logoutToken := req.FormValue("logout_token")
if logoutToken == "" {
// Also try reading from request body as raw JWT
body, err := io.ReadAll(io.LimitReader(req.Body, 64*1024)) // 64KB limit
if err == nil && len(body) > 0 {
logoutToken = string(body)
}
}
if logoutToken == "" {
t.logger.Error("Backchannel logout: missing logout_token")
http.Error(rw, "logout_token required", http.StatusBadRequest)
return
}
// Parse and validate the logout token
claims, err := t.validateLogoutToken(logoutToken)
if err != nil {
t.logger.Errorf("Backchannel logout: token validation failed: %v", err)
// Return 400 for invalid token per spec
http.Error(rw, "Invalid logout token", http.StatusBadRequest)
return
}
// Invalidate session(s) based on sid or sub
if err := t.invalidateSession(claims.SessionID, claims.Subject); err != nil {
t.logger.Errorf("Backchannel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Backchannel logout: successfully invalidated session (sid=%s, sub=%s)",
claims.SessionID, claims.Subject)
// Return 200 OK with empty body per spec
rw.WriteHeader(http.StatusOK)
}
// handleFrontchannelLogout processes OIDC Front-Channel Logout requests.
// It accepts GET requests with 'iss' and 'sid' query parameters that identify
// which session to terminate. The IdP typically loads this URL in an iframe.
//
// According to OpenID Connect Front-Channel Logout 1.0:
// - The request contains 'iss' (issuer) and optionally 'sid' (session ID)
// - The RP should clear the session and return a response (typically empty or image)
// - The response must be cacheable to allow the IdP to load it in an iframe
//
// Parameters:
// - rw: The HTTP response writer
// - req: The HTTP request containing iss and sid parameters
func (t *TraefikOidc) handleFrontchannelLogout(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Processing front-channel logout request")
// Front-channel logout should be GET
if req.Method != http.MethodGet {
t.logger.Errorf("Front-channel logout: invalid method %s, expected GET", req.Method)
http.Error(rw, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Get iss and sid from query parameters
iss := req.URL.Query().Get("iss")
sid := req.URL.Query().Get("sid")
// Validate issuer matches our expected issuer
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Require a matching issuer. An empty iss must be rejected too: accepting a
// missing issuer would let an unauthenticated attacker force-logout any
// session whose sid is known by simply omitting iss.
if iss == "" || iss != expectedIssuer {
t.logger.Errorf("Front-channel logout: issuer validation failed: got %q, expected %q", iss, expectedIssuer)
http.Error(rw, "Invalid issuer", http.StatusBadRequest)
return
}
// Must have at least sid for front-channel logout
if sid == "" {
t.logger.Error("Front-channel logout: missing sid parameter")
http.Error(rw, "sid parameter required", http.StatusBadRequest)
return
}
// Invalidate the session
if err := t.invalidateSession(sid, ""); err != nil {
t.logger.Errorf("Front-channel logout: failed to invalidate session: %v", err)
http.Error(rw, "Failed to invalidate session", http.StatusInternalServerError)
return
}
t.logger.Infof("Front-channel logout: successfully invalidated session (sid=%s)", sid)
// Return a minimal HTML response that's suitable for iframe loading
// Set headers to allow embedding and caching
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.Header().Set("Cache-Control", "no-cache, no-store")
rw.Header().Set("Pragma", "no-cache")
// Allow embedding in iframes from any origin (required for front-channel logout)
rw.Header().Del("X-Frame-Options")
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("<!DOCTYPE html><html><head><title>Logged Out</title></head><body></body></html>"))
}
// validateLogoutToken parses and validates a logout token JWT.
// It verifies the token signature, issuer, audience, and required claims.
//
// Parameters:
// - tokenString: The raw JWT logout token
//
// Returns:
// - The parsed logout token claims
// - An error if validation fails
func (t *TraefikOidc) validateLogoutToken(tokenString string) (*LogoutTokenClaims, error) {
// Parse the JWT
jwt, err := parseJWT(tokenString)
if err != nil {
return nil, fmt.Errorf("failed to parse logout token: %w", err)
}
// Check token type if present
if typ, ok := jwt.Header["typ"].(string); ok {
// The typ should be "logout+jwt" or omitted
if typ != "" && typ != logoutTokenType && typ != "JWT" {
return nil, fmt.Errorf("invalid token type: %s", typ)
}
}
// Verify signature only (not standard claims - logout tokens don't have 'exp')
if err := t.verifyLogoutTokenSignature(jwt, tokenString); err != nil {
return nil, fmt.Errorf("signature verification failed: %w", err)
}
// Extract claims
claims := &LogoutTokenClaims{}
claimsJSON, err := json.Marshal(jwt.Claims)
if err != nil {
return nil, fmt.Errorf("failed to marshal claims: %w", err)
}
if err := json.Unmarshal(claimsJSON, claims); err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
}
// Validate required claims
t.metadataMu.RLock()
expectedIssuer := t.issuerURL
t.metadataMu.RUnlock()
// Validate issuer
if claims.Issuer != expectedIssuer {
return nil, fmt.Errorf("issuer mismatch: got %s, expected %s", claims.Issuer, expectedIssuer)
}
// Validate audience (must contain our client_id)
if !t.validateLogoutTokenAudience(claims.Audience) {
return nil, fmt.Errorf("audience validation failed")
}
// Validate iat (issued at) - must be present and not too old
if claims.IssuedAt == 0 {
return nil, fmt.Errorf("missing iat claim")
}
iatTime := time.Unix(claims.IssuedAt, 0)
// Allow up to 5 minutes clock skew and 10 minutes token age
if time.Since(iatTime) > 15*time.Minute {
return nil, fmt.Errorf("logout token too old: issued at %v", iatTime)
}
// Token should not be from the future (with 5 min clock skew tolerance)
if iatTime.After(time.Now().Add(5 * time.Minute)) {
return nil, fmt.Errorf("logout token issued in the future: %v", iatTime)
}
// Validate events claim - must contain the logout event
if claims.Events == nil {
return nil, fmt.Errorf("missing events claim")
}
if _, ok := claims.Events["http://schemas.openid.net/event/backchannel-logout"]; !ok {
return nil, fmt.Errorf("missing backchannel-logout event in events claim")
}
// Validate that nonce is NOT present (per spec)
if claims.Nonce != "" {
return nil, fmt.Errorf("nonce claim must not be present in logout token")
}
// Must have either sid or sub (or both)
if claims.SessionID == "" && claims.Subject == "" {
return nil, fmt.Errorf("logout token must contain either sid or sub claim")
}
return claims, nil
}
// validateLogoutTokenAudience checks if the logout token audience contains our client_id
func (t *TraefikOidc) validateLogoutTokenAudience(aud interface{}) bool {
switch v := aud.(type) {
case string:
return v == t.clientID
case []interface{}:
for _, a := range v {
if s, ok := a.(string); ok && s == t.clientID {
return true
}
}
case []string:
for _, a := range v {
if a == t.clientID {
return true
}
}
}
return false
}
// verifyLogoutTokenSignature verifies only the signature of a logout token.
// Unlike VerifyJWTSignatureAndClaims, this does NOT validate standard claims like 'exp'
// because logout tokens don't have an expiration claim per OIDC Back-Channel Logout spec.
//
// Parameters:
// - jwt: The parsed JWT structure
// - tokenString: The raw token string for signature verification
//
// Returns:
// - An error if signature verification fails
func (t *TraefikOidc) verifyLogoutTokenSignature(jwt *JWT, tokenString string) error {
t.logger.Debug("Verifying logout token signature")
// Read jwksURL with RLock
t.metadataMu.RLock()
jwksURL := t.jwksURL
t.metadataMu.RUnlock()
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
if jwks == nil {
return fmt.Errorf("JWKS is nil, cannot verify token")
}
kid, ok := jwt.Header["kid"].(string)
if !ok || kid == "" {
return fmt.Errorf("missing key ID in token header")
}
alg, ok := jwt.Header["alg"].(string)
if !ok || alg == "" {
return fmt.Errorf("missing algorithm in token header")
}
// Find the matching key in JWKS
var matchingKey *JWK
for i := range jwks.Keys {
if jwks.Keys[i].Kid == kid {
matchingKey = &jwks.Keys[i]
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
if err := verifySignature(tokenString, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Logout token signature verified successfully")
return nil
}
// invalidateSession marks a session as invalidated in the session invalidation cache.
// It stores entries by both sid and sub if available.
//
// Parameters:
// - sid: The session ID to invalidate (from the 'sid' claim)
// - sub: The subject to invalidate (from the 'sub' claim)
//
// Returns:
// - An error if the invalidation fails
func (t *TraefikOidc) invalidateSession(sid, sub string) error {
if t.sessionInvalidationCache == nil {
return fmt.Errorf("session invalidation cache not initialized")
}
now := time.Now().Unix()
// Store by session ID
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sid: %s", sid)
}
// Store by subject (invalidates all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
t.sessionInvalidationCache.Set(key, now, sessionInvalidationTTL)
t.logger.Debugf("Invalidated session by sub: %s", sub)
}
return nil
}
// isSessionInvalidated checks if a session has been invalidated via backchannel
// or front-channel logout.
//
// Parameters:
// - sid: The session ID to check
// - sub: The subject to check
// - sessionCreatedAt: When the session was created (to compare against invalidation time)
//
// Returns:
// - true if the session has been invalidated, false otherwise
func (t *TraefikOidc) isSessionInvalidated(sid, sub string, sessionCreatedAt time.Time) bool {
if t.sessionInvalidationCache == nil {
return false
}
// Truncate session creation time to seconds for fair comparison with Unix timestamps
sessionCreatedAtSec := sessionCreatedAt.Truncate(time.Second)
// Check by session ID first (more specific)
if sid != "" {
key := t.buildSessionInvalidationKey("sid", sid)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Session was invalidated at or after it was created
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sid: %s", sid)
return true
}
}
}
}
// Check by subject (all sessions for this user)
if sub != "" {
key := t.buildSessionInvalidationKey("sub", sub)
if val, found := t.sessionInvalidationCache.Get(key); found {
if invalidatedAt, ok := val.(int64); ok {
// Sessions for this subject created at or before invalidation are invalid
invalidationTime := time.Unix(invalidatedAt, 0)
if !invalidationTime.Before(sessionCreatedAtSec) {
t.logger.Debugf("Session invalidated by sub: %s", sub)
return true
}
}
}
}
return false
}
// buildSessionInvalidationKey creates a cache key for session invalidation
func (t *TraefikOidc) buildSessionInvalidationKey(keyType, value string) string {
return fmt.Sprintf("session_invalidation:%s:%s", keyType, value)
}
// extractSessionInfo extracts sid and sub from an ID token for session tracking
func (t *TraefikOidc) extractSessionInfo(idToken string) (sid, sub string, createdAt time.Time) {
if idToken == "" {
return "", "", time.Time{}
}
jwt, err := parseJWT(idToken)
if err != nil {
return "", "", time.Time{}
}
// Extract sid (session ID)
if sidVal, ok := jwt.Claims["sid"].(string); ok {
sid = sidVal
}
// Extract sub (subject)
if subVal, ok := jwt.Claims["sub"].(string); ok {
sub = subVal
}
// Extract iat for session creation time
if iatVal, ok := jwt.Claims["iat"].(float64); ok {
createdAt = time.Unix(int64(iatVal), 0)
} else {
// Default to now if iat not present
createdAt = time.Now()
}
return sid, sub, createdAt
}
// determineLogoutPath checks if the given path matches any logout URL
func (t *TraefikOidc) determineLogoutPath(path string) string {
// Check backchannel logout path
if t.backchannelLogoutPath != "" && path == t.backchannelLogoutPath {
return "backchannel"
}
// Check front-channel logout path
if t.frontchannelLogoutPath != "" && path == t.frontchannelLogoutPath {
return "frontchannel"
}
// Check regular logout path (for RP-initiated logout)
if path == t.logoutURLPath {
return "rp"
}
return ""
}
// normalizeLogoutPath ensures logout paths start with / and prevents open redirects
func normalizeLogoutPath(path string) string {
if path == "" {
return ""
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// Prevent open redirect: ensure second character is not / or \
// This prevents URLs like //example.com or /\example.com from being treated as absolute URLs
if len(path) > 1 && (path[1] == '/' || path[1] == '\\') {
// Strip leading slashes/backslashes and re-normalize
path = strings.TrimLeft(path, "/\\")
if path != "" {
path = "/" + path
}
}
return path
}
+1667
View File
File diff suppressed because it is too large Load Diff
+242 -32
View File
@@ -9,6 +9,7 @@ import (
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
@@ -16,6 +17,7 @@ import (
"text/template"
"time"
telemetry "github.com/lukaszraczylo/oss-telemetry"
"golang.org/x/time/rate"
)
@@ -23,6 +25,11 @@ const (
ConstSessionTimeout = 86400
)
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
// does not deduplicate client-side (the server does), so the gate stays here.
var telemetryStartupOnce sync.Once
// isTestMode detects if the code is running in a test environment.
func isTestMode() bool {
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
@@ -89,6 +96,13 @@ var defaultExcludedURLs = map[string]struct{}{
// - The configured TraefikOidc handler ready to process requests.
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
telemetryStartupOnce.Do(func() {
// Only stamped release builds phone home; dev/local/test builds keep the
// devPluginVersion sentinel (see version.go) and stay silent.
if traefikoidcPluginVersion != devPluginVersion {
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
}
})
return NewWithContext(ctx, config, next, name)
}
@@ -99,26 +113,40 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
config = CreateConfig()
}
if config.SessionEncryptionKey == "" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
}
logger := NewLogger(config.LogLevel)
if len(config.SessionEncryptionKey) < minEncryptionKeyLength {
if runtime.Compiler == "yaegi" {
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
logger.Infof("Session encryption key is too short; using default key for analyzer")
} else {
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
}
// Fail closed on invalid configuration. Validate() enforces the security
// constraints (required fields, HTTPS-only URLs, key length, excludedURLs
// safety, rate-limit floor, audience format, ...) that were previously
// unenforced because this constructor never called it. Crucially it rejects
// an empty or too-short SessionEncryptionKey instead of silently
// substituting a public hardcoded key, which would let an attacker forge
// any session. Traefik's yaegi plugin analyzer supplies a valid key via
// .traefik.yml testData, so it passes; only misconfigured deployments fail.
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid configuration: %w", err)
}
// Setup HTTP client
caPool, err := config.loadCACertPool()
if err != nil {
return nil, fmt.Errorf("failed to load CA certificates: %w", err)
}
if config.InsecureSkipVerify {
logger.Errorf("SECURITY WARNING: InsecureSkipVerify is enabled for the OIDC provider. TLS certificate verification is DISABLED. Do not use in production.")
}
var httpClient *http.Client
if config.HTTPClient != nil {
httpClient = config.HTTPClient
} else {
httpClient = CreateDefaultHTTPClient()
defaultCfg := DefaultHTTPClientConfig()
defaultCfg.RootCAs = caPool
defaultCfg.InsecureSkipVerify = config.InsecureSkipVerify
httpClient = CreatePooledHTTPClient(defaultCfg)
}
tokenCfg := TokenHTTPClientConfig()
tokenCfg.RootCAs = caPool
tokenCfg.InsecureSkipVerify = config.InsecureSkipVerify
tokenHTTPClient := CreatePooledHTTPClient(tokenCfg)
goroutineWG := &sync.WaitGroup{}
cacheManager := GetGlobalCacheManagerWithConfig(goroutineWG, config)
@@ -155,6 +183,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
clientID: config.ClientID,
clientSecret: config.ClientSecret,
clientAuthMethod: func() string {
if config.ClientAuthMethod != "" {
return config.ClientAuthMethod
}
return "client_secret_post"
}(),
audience: func() string {
if config.Audience != "" {
return config.Audience
@@ -181,6 +215,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}(),
forceHTTPS: config.ForceHTTPS,
enablePKCE: config.EnablePKCE,
extraAuthParams: config.ExtraAuthParams,
overrideScopes: config.OverrideScopes,
strictAudienceValidation: config.StrictAudienceValidation,
allowOpaqueTokens: config.AllowOpaqueTokens,
@@ -199,9 +234,9 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: cacheManager.GetSharedTokenCache(),
httpClient: httpClient,
tokenHTTPClient: CreateTokenHTTPClient(),
tokenHTTPClient: tokenHTTPClient,
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedUserDomains: createCaseInsensitiveStringMap(config.AllowedUserDomains),
allowedUsers: createCaseInsensitiveStringMap(config.AllowedUsers),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
@@ -212,16 +247,70 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
}
return 60 * time.Second
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
maxRefreshTokenAge: func() time.Duration {
// 0 (or unset) disables the heuristic; negative is rejected by Validate.
if config.MaxRefreshTokenAgeSeconds > 0 {
return time.Duration(config.MaxRefreshTokenAgeSeconds) * time.Second
}
return 0
}(),
tokenCleanupStopChan: make(chan struct{}),
metadataRefreshStopChan: make(chan struct{}),
ctx: pluginCtx,
cancelFunc: cancelFunc,
suppressDiagnosticLogs: isTestMode(),
securityHeadersApplier: config.GetSecurityHeadersApplier(),
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
dcrConfig: config.DynamicClientRegistration,
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
minimalHeaders: config.MinimalHeaders,
stripAuthCookies: config.StripAuthCookies,
enableBackchannelLogout: config.EnableBackchannelLogout,
enableFrontchannelLogout: config.EnableFrontchannelLogout,
backchannelLogoutPath: normalizeLogoutPath(config.BackchannelLogoutURL),
frontchannelLogoutPath: normalizeLogoutPath(config.FrontchannelLogoutURL),
sessionInvalidationCache: cacheManager.GetSharedSessionInvalidationCache(),
refreshResultCache: cacheManager.GetSharedRefreshResultCache(),
enableBearerAuth: config.EnableBearerAuth,
stripAuthorizationHeader: config.StripAuthorizationHeader,
bearerEmitWWWAuthenticate: config.BearerEmitWWWAuthenticate,
bearerOverridesCookie: config.BearerOverridesCookie,
bearerIdentifierClaim: func() string {
if config.BearerIdentifierClaim != "" {
return config.BearerIdentifierClaim
}
return "sub"
}(),
maxIdentifierLength: func() int {
if config.MaxIdentifierLength > 0 {
return config.MaxIdentifierLength
}
return 256
}(),
maxTokenAge: func() time.Duration {
if config.MaxTokenAgeSeconds > 0 {
return time.Duration(config.MaxTokenAgeSeconds) * time.Second
}
return 24 * time.Hour
}(),
bearerFailureThreshold: func() int {
if config.BearerFailureThreshold > 0 {
return config.BearerFailureThreshold
}
return 20
}(),
bearerFailureWindow: func() time.Duration {
if config.BearerFailureWindowSeconds > 0 {
return time.Duration(config.BearerFailureWindowSeconds) * time.Second
}
return 60 * time.Second
}(),
bearerFailurePenalty: func() time.Duration {
if config.BearerFailurePenaltySeconds > 0 {
return time.Duration(config.BearerFailurePenaltySeconds) * time.Second
}
return 60 * time.Second
}(),
}
// Log audience configuration
@@ -231,15 +320,59 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
t.logger.Debugf("No custom audience specified, using clientID as audience: %s", t.clientID)
}
// Bearer-auth startup validation. The bearer path is M2M-only and demands
// a non-default audience so tokens issued for a different resource cannot
// be replayed against this service. The BearerIdentifierClaim guard blocks
// the `email` claim explicitly — without email_verified enforcement (out of
// scope for M2M), trusting email is a spoofing vector for federated IdPs.
// See spec §7.9 / §13.
if config.EnableBearerAuth {
if config.Audience == "" {
cancelFunc()
return nil, fmt.Errorf("EnableBearerAuth=true requires Audience to be set explicitly (cannot default to clientID — that path accepts ID tokens)")
}
if t.bearerIdentifierClaim == "email" {
cancelFunc()
return nil, fmt.Errorf("enableBearerAuth=true with bearerIdentifierClaim=%q is rejected: email-based identity without email_verified enforcement is a spoofing vector for federated IdPs (use \"sub\" or a custom claim; cookie-path userIdentifierClaim is unaffected)", t.bearerIdentifierClaim)
}
if !config.StrictAudienceValidation {
t.logger.Infof("EnableBearerAuth=true with StrictAudienceValidation=false: recommend enabling strict audience validation for hardening")
}
t.bearerFailureTracker = newBearerFailureTracker(
t.bearerFailureThreshold, t.bearerFailureWindow, t.bearerFailurePenalty,
)
t.logger.Infof("Bearer-token auth enabled: audience=%q identifierClaim=%q stripAuthz=%t bearerOverridesCookie=%t maxTokenAge=%s",
config.Audience, t.bearerIdentifierClaim, t.stripAuthorizationHeader, t.bearerOverridesCookie, t.maxTokenAge)
}
// Convert sessionMaxAge from seconds to duration (0 will use default 24 hours)
sessionMaxAge := time.Duration(config.SessionMaxAge) * time.Second
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger) // Safe to ignore: session manager creation with fallback to defaults
sessionManager, err := NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, config.CookiePrefix, sessionMaxAge, t.logger)
if err != nil {
cancelFunc()
return nil, fmt.Errorf("failed to create session manager: %w", err)
}
t.sessionManager = sessionManager
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
// Initialize token resilience manager with default configuration
tokenResilienceConfig := DefaultTokenResilienceConfig()
t.tokenResilienceManager = NewTokenResilienceManager(tokenResilienceConfig, t.logger)
// Coalesces concurrent refresh-token grants per refresh_token to one upstream
// call, preventing the thundering herd that yields invalid_grant when the IdP
// rotates refresh tokens (Zitadel/Authentik default).
t.refreshCoordinator = NewRefreshCoordinator(DefaultRefreshCoordinatorConfig(), t.logger)
if config.ClientAuthMethod == "private_key_jwt" {
signer, err := buildClientAssertionSignerFromConfig(config)
if err != nil {
cancelFunc()
return nil, fmt.Errorf("failed to build client assertion signer: %w", err)
}
t.clientAssertion = signer
}
t.extractClaimsFunc = extractClaims
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
@@ -287,17 +420,22 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
startReplayCacheCleanup(pluginCtx, logger)
// Start memory monitoring for leak detection and performance insights
// Start memory monitoring for leak detection and performance insights.
// The interval is clamped to MinMemoryMonitorInterval (30s) inside
// StartMonitoring; tests that need deterministic sampling should call
// MemoryMonitor.Refresh() directly instead of waiting on a fast ticker.
memoryMonitor := GetGlobalMemoryMonitor()
monitorInterval := 60 * time.Second
if isTestMode() {
monitorInterval = 100 * time.Millisecond // Fast interval for tests
}
memoryMonitor.StartMonitoring(pluginCtx, monitorInterval)
memoryMonitor.StartMonitoring(pluginCtx, DefaultMemoryMonitorInterval)
logger.Debug("Started global memory monitoring")
logger.Debugf("TraefikOidc.New: Final t.scopes initialized to: %v", t.scopes)
// Log callback URL configuration to help diagnose redirect loop issues.
// If callbackURL is a full URL instead of a path, the callback matching
// in ServeHTTP will silently fail because req.URL.Path is compared directly.
logger.Debugf("TraefikOidc.New: callbackURL (redirURLPath) configured as: %q", t.redirURLPath)
logger.Debugf("TraefikOidc.New: logoutURLPath configured as: %q", t.logoutURLPath)
t.providerURL = config.ProviderURL
// Use singleton resource manager for metadata initialization
@@ -305,6 +443,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
// Add reference for this instance
rm.AddReference(name)
registerLiveInstance()
// Initialize metadata in a goroutine with proper tracking
if t.goroutineWG != nil {
@@ -382,13 +521,58 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
// Parameters:
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
// SSRF defense (audit ranks 3 & 4): a discovery document is attacker-
// influenced when the provider or its TLS is compromised. Reject any
// discovered endpoint pointed at a blocked address before the plugin issues
// outbound requests to it, so it can never be used to reach the cloud
// metadata service or an internal host.
allowLoopback := false
if pu, err := url.Parse(t.providerURL); err == nil {
allowLoopback = isLoopbackHost(pu.Hostname())
}
sanitize := func(name, raw string) string {
if err := t.validateDiscoveredEndpoint(raw, allowLoopback); err != nil {
t.logger.Errorf("Ignoring discovered %s endpoint %q: %v", name, raw, err)
return ""
}
return raw
}
metadata.JWKSURL = sanitize("jwks_uri", metadata.JWKSURL)
metadata.AuthURL = sanitize("authorization", metadata.AuthURL)
metadata.TokenURL = sanitize("token", metadata.TokenURL)
metadata.RevokeURL = sanitize("revocation", metadata.RevokeURL)
metadata.EndSessionURL = sanitize("end_session", metadata.EndSessionURL)
metadata.RegistrationURL = sanitize("registration", metadata.RegistrationURL)
metadata.IntrospectionURL = sanitize("introspection", metadata.IntrospectionURL)
// The introspection request authenticates with the client secret via HTTP
// Basic, so the endpoint must live on the same host as the operator-
// configured provider; otherwise a poisoned discovery document could
// exfiltrate the client secret to an attacker-controlled host.
if metadata.IntrospectionURL != "" && t.providerURL != "" && !sameHost(metadata.IntrospectionURL, t.providerURL) {
t.logger.Errorf("Ignoring introspection endpoint %q: host does not match configured providerURL", metadata.IntrospectionURL)
metadata.IntrospectionURL = ""
}
// Pin the discovered issuer to the operator-configured provider host. The
// issuer is the trust anchor for JWT issuer validation, so a poisoned
// discovery document advertising an attacker-chosen issuer must never be
// stored. Real providers (Google, Azure, Keycloak, Okta, Auth0) keep the
// issuer on the same host as the configured providerURL. On mismatch, leave
// issuerURL empty/unchanged so downstream issuer validation fails closed
// rather than trusting the attacker-chosen value.
discoveredIssuer := metadata.Issuer
if discoveredIssuer != "" && t.providerURL != "" && !sameHost(discoveredIssuer, t.providerURL) {
t.logger.Errorf("Ignoring discovered issuer %q: host does not match configured providerURL", discoveredIssuer)
discoveredIssuer = ""
}
t.metadataMu.Lock()
t.jwksURL = metadata.JWKSURL
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.issuerURL = discoveredIssuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
@@ -398,6 +582,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
introspectionURL := t.introspectionURL
registrationURL := t.registrationURL
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
// this directly instead of acquiring metadataMu.RLock per request.
t.metadataSnapshot.Store(&MetadataSnapshot{
IssuerURL: discoveredIssuer,
JWKSURL: metadata.JWKSURL,
TokenURL: metadata.TokenURL,
AuthURL: metadata.AuthURL,
RevocationURL: metadata.RevokeURL,
EndSessionURL: metadata.EndSessionURL,
IntrospectionURL: metadata.IntrospectionURL,
RegistrationURL: metadata.RegistrationURL,
})
t.metadataMu.Unlock()
// Log introspection endpoint availability for opaque token support
@@ -433,6 +630,19 @@ func (t *TraefikOidc) performDynamicClientRegistration() {
t.dcrConfig,
t.providerURL,
)
// Set up storage backend for credentials persistence
if t.dcrConfig.PersistCredentials {
cacheManager := GetGlobalCacheManagerWithConfig(t.goroutineWG, nil)
store, err := NewDCRCredentialsStore(t.dcrConfig, cacheManager, t.logger)
if err != nil {
t.logger.Errorf("Failed to create DCR credentials store: %v", err)
// Continue without persistence - registration will still work
} else {
t.dynamicClientRegistrar.SetStore(store)
t.logger.Debugf("DCR credentials store initialized with backend: %s", t.dcrConfig.StorageBackend)
}
}
}
// Get registration endpoint (from metadata or config override)
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"gopkg.in/yaml.v3"
)
// Config Marshalling Tests
// Config Marshaling Tests
func TestConfig_MarshalJSON(t *testing.T) {
config := &Config{
+2
View File
@@ -194,6 +194,7 @@ func TestGoroutineLeakPrevention_MultipleInstances(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
@@ -322,6 +323,7 @@ func TestGoroutineLeakPrevention_BackgroundTaskCleanup(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client-id"
config.ClientSecret = "test-client-secret"
config.CallbackURL = "/callback"
handler, err := New(ctx, nil, config, "test")
if err != nil {
+70 -68
View File
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
@@ -25,38 +26,47 @@ func TestInitializeMetadata(t *testing.T) {
name: "successful metadata initialization",
providerURL: "",
setupMock: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer must share the host with providerURL (the httptest
// server), otherwise the discovery doc is rejected as poisoned
// (audit ranks 21/22). Real providers keep issuer + endpoints on
// the same host, so derive them all from the server URL.
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://provider.example.com",
AuthURL: "https://provider.example.com/auth",
TokenURL: "https://provider.example.com/token",
JWKSURL: "https://provider.example.com/jwks",
RevokeURL: "https://provider.example.com/revoke",
EndSessionURL: "https://provider.example.com/logout",
Issuer: srv.URL,
AuthURL: srv.URL + "/auth",
TokenURL: srv.URL + "/token",
JWKSURL: srv.URL + "/jwks",
RevokeURL: srv.URL + "/revoke",
EndSessionURL: srv.URL + "/logout",
})
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL != "https://provider.example.com/auth" {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL != "https://provider.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
if oidc.jwksURL != "https://provider.example.com/jwks" {
if oidc.jwksURL == "" || !strings.HasSuffix(oidc.jwksURL, "/jwks") {
t.Errorf("expected jwksURL to be set, got %s", oidc.jwksURL)
}
if oidc.revocationURL != "https://provider.example.com/revoke" {
if oidc.revocationURL == "" || !strings.HasSuffix(oidc.revocationURL, "/revoke") {
t.Errorf("expected revocationURL to be set, got %s", oidc.revocationURL)
}
if oidc.endSessionURL != "https://provider.example.com/logout" {
if oidc.endSessionURL == "" || !strings.HasSuffix(oidc.endSessionURL, "/logout") {
t.Errorf("expected endSessionURL to be set, got %s", oidc.endSessionURL)
}
if oidc.issuerURL == "" {
t.Errorf("expected issuerURL to be pinned to provider host, got empty")
}
},
wantPanic: false,
},
@@ -115,24 +125,27 @@ func TestInitializeMetadata(t *testing.T) {
name: "partial metadata response",
providerURL: "",
setupMock: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer host must match providerURL (audit ranks 21/22).
var srv *httptest.Server
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Only return some fields
json.NewEncoder(w).Encode(map[string]string{
"issuer": "https://partial.example.com",
"authorization_endpoint": "https://partial.example.com/auth",
"token_endpoint": "https://partial.example.com/token",
"issuer": srv.URL,
"authorization_endpoint": srv.URL + "/auth",
"token_endpoint": srv.URL + "/token",
// Missing jwks_uri, revocation_endpoint, end_session_endpoint
})
}
}))
return srv
},
validateFunc: func(t *testing.T, oidc *TraefikOidc) {
if oidc.authURL != "https://partial.example.com/auth" {
if oidc.authURL == "" || !strings.HasSuffix(oidc.authURL, "/auth") {
t.Errorf("expected authURL to be set, got %s", oidc.authURL)
}
if oidc.tokenURL != "https://partial.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set, got %s", oidc.tokenURL)
}
// JWKS URL and others may be empty
@@ -197,20 +210,22 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
requestCount := 0
var mu sync.Mutex
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
requestCount++
mu.Unlock()
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://concurrent.example.com",
AuthURL: "https://concurrent.example.com/auth",
TokenURL: "https://concurrent.example.com/token",
JWKSURL: "https://concurrent.example.com/jwks",
RevokeURL: "https://concurrent.example.com/revoke",
EndSessionURL: "https://concurrent.example.com/logout",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
RevokeURL: server.URL + "/revoke",
EndSessionURL: server.URL + "/logout",
})
}
}))
@@ -249,7 +264,7 @@ func TestInitializeMetadata_Concurrency(t *testing.T) {
oidc.initializeMetadata(server.URL)
// Verify initialization
if oidc.tokenURL != "https://concurrent.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("expected tokenURL to be set")
}
}()
@@ -341,17 +356,19 @@ func TestProviderDetection(t *testing.T) {
// TestInitializationWaiting tests waiting for initialization to complete
func TestInitializationWaiting(t *testing.T) {
t.Run("wait for initialization completion", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay response to simulate slow initialization
time.Sleep(100 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://slow.example.com",
AuthURL: "https://slow.example.com/auth",
TokenURL: "https://slow.example.com/token",
JWKSURL: "https://slow.example.com/jwks",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
})
}
}))
@@ -388,7 +405,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// Success
if oidc.tokenURL != "https://slow.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Error("expected tokenURL to be set after initialization")
}
case <-time.After(2 * time.Second):
@@ -397,17 +414,19 @@ func TestInitializationWaiting(t *testing.T) {
})
t.Run("multiple waiters for initialization", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Delay to ensure multiple waiters
time.Sleep(50 * time.Millisecond)
if strings.HasSuffix(r.URL.Path, "/.well-known/openid-configuration") {
w.Header().Set("Content-Type", "application/json")
// Issuer host must match providerURL (audit ranks 21/22).
json.NewEncoder(w).Encode(ProviderMetadata{
Issuer: "https://multi.example.com",
AuthURL: "https://multi.example.com/auth",
TokenURL: "https://multi.example.com/token",
JWKSURL: "https://multi.example.com/jwks",
Issuer: server.URL,
AuthURL: server.URL + "/auth",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/jwks",
})
}
}))
@@ -452,7 +471,7 @@ func TestInitializationWaiting(t *testing.T) {
select {
case <-oidc.initComplete:
// All waiters should see the same initialized state
if oidc.tokenURL != "https://multi.example.com/token" {
if oidc.tokenURL == "" || !strings.HasSuffix(oidc.tokenURL, "/token") {
t.Errorf("waiter %d: expected tokenURL to be set", id)
}
case <-time.After(2 * time.Second):
@@ -484,9 +503,8 @@ func TestFirstRequestHandling(t *testing.T) {
defer server.Close()
oidc := &TraefikOidc{
providerURL: server.URL,
firstRequestReceived: false,
firstRequestMutex: sync.Mutex{},
providerURL: server.URL,
firstRequestStarted: 0,
httpClient: &http.Client{
Timeout: 5 * time.Second,
},
@@ -508,19 +526,13 @@ func TestFirstRequestHandling(t *testing.T) {
},
}
// Simulate first request processing
oidc.firstRequestMutex.Lock()
if !oidc.firstRequestReceived {
oidc.firstRequestReceived = true
oidc.firstRequestMutex.Unlock()
// Simulate first request processing — single-firing via CAS.
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
// This would normally be called asynchronously
go func() {
oidc.initializeMetadata(server.URL)
// initComplete is closed internally by initializeMetadata
}()
} else {
oidc.firstRequestMutex.Unlock()
}
// Wait for initialization
@@ -556,9 +568,8 @@ func TestFirstRequestHandling(t *testing.T) {
defer server.Close()
oidc := &TraefikOidc{
providerURL: server.URL,
firstRequestReceived: false,
firstRequestMutex: sync.Mutex{},
providerURL: server.URL,
firstRequestStarted: 0,
httpClient: &http.Client{
Timeout: 5 * time.Second,
},
@@ -580,31 +591,22 @@ func TestFirstRequestHandling(t *testing.T) {
},
}
// Simulate multiple concurrent "first" requests
// Simulate multiple concurrent "first" requests — only one CAS winner
// fires the bootstrap path.
const numRequests = 10
var wg sync.WaitGroup
wg.Add(numRequests)
initStarted := 0
var initMu sync.Mutex
var initStarted int32
for i := 0; i < numRequests; i++ {
go func() {
defer wg.Done()
oidc.firstRequestMutex.Lock()
if !oidc.firstRequestReceived {
oidc.firstRequestReceived = true
oidc.firstRequestMutex.Unlock()
initMu.Lock()
initStarted++
initMu.Unlock()
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
atomic.AddInt32(&initStarted, 1)
// Only one should actually start initialization
oidc.initializeMetadata(server.URL)
} else {
oidc.firstRequestMutex.Unlock()
}
}()
}
@@ -612,8 +614,8 @@ func TestFirstRequestHandling(t *testing.T) {
wg.Wait()
// Verify only one initialization was started
if initStarted != 1 {
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
if atomic.LoadInt32(&initStarted) != 1 {
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
}
// The metadata endpoint might be called once or not at all depending on timing
+549 -56
View File
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com", // Required for initialization check
}
close(oidc.initComplete)
@@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
}
}
// TestServeHTTP_EventStream tests the event-stream bypass functionality
// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the
// handshake must skip the OIDC redirect dance (clients can't follow it
// mid-stream) but it must STILL require an authenticated session, otherwise
// any caller could reach the backend by setting Accept: text/event-stream.
func TestServeHTTP_EventStream(t *testing.T) {
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
t.Run("unauthenticated_request_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated SSE bypass")
}
})
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
issuerURL: "https://provider.example.com",
t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
// Build an authenticated session and inject its cookies onto req.
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated SSE request to be forwarded to backend")
}
if forwardedUser != "user@example.com" {
t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser)
}
})
}
// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket
// handshake bypasses the OIDC redirect (clients can't follow it) but the
// session must already be authenticated, otherwise the backend is exposed
// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`.
func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
sessionManager := createTestSessionManager(t)
newOidc := func(next http.Handler) *TraefikOidc {
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
}
close(oidc.initComplete)
return oidc
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/events", nil)
req.Header.Set("Accept", "text/event-stream")
rw := httptest.NewRecorder()
t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) {
nextCalled := false
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
}))
oidc.ServeHTTP(rw, req)
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
rw := httptest.NewRecorder()
if !nextCalled {
t.Error("expected event-stream request to bypass OIDC")
}
oidc.ServeHTTP(rw, req)
if rw.Code != http.StatusUnauthorized {
t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code)
}
if nextCalled {
t.Error("backend handler must NOT be called for unauthenticated WS bypass")
}
})
t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) {
nextCalled := false
var forwardedUser string
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
forwardedUser = r.Header.Get("X-Forwarded-User")
}))
req := httptest.NewRequest("GET", "/ws", nil)
// Mixed-case + multi-token Connection header to exercise parsing.
req.Header.Set("Connection", "keep-alive, Upgrade")
req.Header.Set("Upgrade", "WebSocket")
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("failed to create test session: %v", err)
}
session.SetUserIdentifier("ws-user@example.com")
if err := session.SetAuthenticated(true); err != nil {
t.Fatalf("failed to mark session authenticated: %v", err)
}
setupRW := httptest.NewRecorder()
if err := session.Save(req, setupRW); err != nil {
t.Fatalf("failed to save session: %v", err)
}
for _, c := range setupRW.Result().Cookies() {
req.AddCookie(c)
}
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if !nextCalled {
t.Fatal("expected authenticated WS handshake to be forwarded to backend")
}
if forwardedUser != "ws-user@example.com" {
t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser)
}
})
t.Run("plain_http_does_not_bypass", func(t *testing.T) {
// Sanity: requests without Upgrade headers must NOT hit the WS
// bypass branch (otherwise the new code path could short-circuit
// normal authentication).
oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("backend must not be called for unauthenticated plain HTTP")
}))
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Connection", "keep-alive")
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
if rw.Code == http.StatusOK {
t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200")
}
})
}
// TestServeHTTP_InitializationTimeout tests initialization timeout handling
@@ -120,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close this to simulate timeout
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
}
req := httptest.NewRequest("GET", "/protected", nil)
@@ -155,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -185,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -215,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -256,7 +408,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "successful authorization with email",
setupSession: func() *MockSessionData {
session := &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -288,7 +440,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "no email triggers reauth",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "",
userIdentifier: "",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -309,7 +461,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "roles and groups authorization",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -342,7 +494,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "unauthorized role/group returns 403",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -369,7 +521,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "template headers processing",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
isDirty: false,
@@ -401,7 +553,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
name: "OPTIONS request with CORS",
setupSession: func() *MockSessionData {
return &MockSessionData{
email: "user@example.com",
userIdentifier: "user@example.com",
idToken: "test-id-token",
accessToken: "test-access-token",
}
@@ -452,7 +604,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
manager: &SessionManager{logger: NewLogger("debug")},
}
// Copy values from mock to concrete session
concreteSession.SetEmail(session.email)
concreteSession.SetUserIdentifier(session.userIdentifier)
concreteSession.SetIDToken(session.idToken)
concreteSession.SetAccessToken(session.accessToken)
concreteSession.SetRefreshToken(session.refreshToken)
@@ -502,23 +654,23 @@ func TestProcessAuthorizedRequest(t *testing.T) {
// MockSessionData is a test implementation of SessionData interface
type MockSessionData struct {
email string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
userIdentifier string
idToken string
accessToken string
refreshToken string
csrf string
nonce string
codeVerifier string
redirectCount int
authenticated bool
isDirty bool
}
func (m *MockSessionData) GetEmail() string { return m.email }
func (m *MockSessionData) GetUserIdentifier() string { return m.userIdentifier }
func (m *MockSessionData) GetIDToken() string { return m.idToken }
func (m *MockSessionData) GetAccessToken() string { return m.accessToken }
func (m *MockSessionData) GetRefreshToken() string { return m.refreshToken }
func (m *MockSessionData) SetEmail(email string) { m.email = email }
func (m *MockSessionData) SetUserIdentifier(userIdentifier string) { m.userIdentifier = userIdentifier }
func (m *MockSessionData) SetIDToken(token string) { m.idToken = token }
func (m *MockSessionData) SetAccessToken(token string) { m.accessToken = token }
func (m *MockSessionData) SetRefreshToken(token string) { m.refreshToken = token }
@@ -588,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
minimalHeaders: tt.minimalHeaders,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -610,7 +762,7 @@ func TestMinimalHeaders(t *testing.T) {
}
// Set up session data
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Call processAuthorizedRequest directly
@@ -665,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
minimalHeaders: true, // Enable minimal headers
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
@@ -685,7 +837,7 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Fatalf("Failed to get session: %v", err)
}
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
@@ -710,3 +862,344 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
}
}
// TestStripAuthCookies tests the stripAuthCookies configuration option.
// This addresses GitHub issue #122 - OIDC cookies bloating backend requests.
func TestStripAuthCookies(t *testing.T) {
tests := []struct {
name string
stripAuthCookies bool
expectOIDCCookies bool
expectAppCookies bool
}{
{
name: "stripAuthCookies=false (default) forwards all cookies",
stripAuthCookies: false,
expectOIDCCookies: true,
expectAppCookies: true,
},
{
name: "stripAuthCookies=true strips OIDC cookies but keeps app cookies",
stripAuthCookies: true,
expectOIDCCookies: false,
expectAppCookies: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: tt.stripAuthCookies,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
// Get a valid session first (before adding fake cookies)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Now add OIDC session cookies (simulating what the browser would send)
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_1", Value: "chunk1"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "r", Value: "refresh-token"})
// Add non-OIDC application cookies (these must always pass through)
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "app-session-id"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
// Check for OIDC cookies in captured cookies
hasOIDCCookie := false
hasAppSession := false
hasTheme := false
for _, c := range capturedCookies {
if len(c.Name) >= len(cookiePrefix) && c.Name[:len(cookiePrefix)] == cookiePrefix {
hasOIDCCookie = true
}
if c.Name == "my_app_session" {
hasAppSession = true
}
if c.Name == "theme" {
hasTheme = true
}
}
if tt.expectOIDCCookies && !hasOIDCCookie {
t.Error("expected OIDC cookies to be forwarded to backend")
}
if !tt.expectOIDCCookies && hasOIDCCookie {
t.Error("expected OIDC cookies to be stripped before forwarding to backend")
}
if tt.expectAppCookies && !hasAppSession {
t.Error("expected my_app_session cookie to be forwarded to backend")
}
if tt.expectAppCookies && !hasTheme {
t.Error("expected theme cookie to be forwarded to backend")
}
})
}
}
// TestStripAuthCookies_NoCookies verifies stripping works when the request has no cookies.
func TestStripAuthCookies_NoCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected no cookies, got %d", len(capturedCookies))
}
}
// TestStripAuthCookies_OnlyOIDCCookies verifies that when all cookies are OIDC cookies,
// the Cookie header is empty after stripping.
func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
var capturedCookieHeader string
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookieHeader = r.Header.Get("Cookie")
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
cookiePrefix := sessionManager.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only OIDC cookies
req.AddCookie(&http.Cookie{Name: cookiePrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "s_0", Value: "chunk0"})
req.AddCookie(&http.Cookie{Name: cookiePrefix + "a", Value: "access-token"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 0 {
t.Errorf("expected all cookies to be stripped, got %d", len(capturedCookies))
}
if capturedCookieHeader != "" {
t.Errorf("expected empty Cookie header, got %q", capturedCookieHeader)
}
}
// TestStripAuthCookies_OnlyAppCookies verifies that non-OIDC cookies pass through
// untouched when stripping is enabled.
func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
sessionManager := createTestSessionManager(t)
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add only non-OIDC cookies
req.AddCookie(&http.Cookie{Name: "my_app_session", Value: "abc123"})
req.AddCookie(&http.Cookie{Name: "lang", Value: "en"})
req.AddCookie(&http.Cookie{Name: "theme", Value: "dark"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
if len(capturedCookies) != 3 {
t.Errorf("expected 3 cookies, got %d", len(capturedCookies))
}
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
for _, expected := range []string{"my_app_session", "lang", "theme"} {
if !cookieNames[expected] {
t.Errorf("expected cookie %q to be forwarded", expected)
}
}
}
// TestStripAuthCookies_CustomPrefix verifies stripping works with a custom cookie prefix.
func TestStripAuthCookies_CustomPrefix(t *testing.T) {
var capturedCookies []*http.Cookie
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedCookies = r.Cookies()
w.WriteHeader(http.StatusOK)
})
// Create session manager with custom prefix
sm, err := NewSessionManager("test-encryption-key-32-characters", false, "", "myapp_oidc_", 0, NewLogger("debug"))
if err != nil {
t.Fatalf("Failed to create session manager: %v", err)
}
customPrefix := sm.GetCookiePrefix()
oidc := &TraefikOidc{
next: next,
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sm,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
stripAuthCookies: true,
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
return map[string]interface{}{"email": "user@example.com"}, nil
},
}
close(oidc.initComplete)
req := httptest.NewRequest("GET", "/protected", nil)
rw := httptest.NewRecorder()
session, err := sm.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
// Add cookies with the custom prefix (should be stripped)
req.AddCookie(&http.Cookie{Name: customPrefix + "m", Value: "session-data"})
req.AddCookie(&http.Cookie{Name: customPrefix + "s_0", Value: "chunk0"})
// Add default-prefix cookie (should NOT be stripped — different prefix)
req.AddCookie(&http.Cookie{Name: "_oidc_raczylo_m", Value: "other-session"})
// Add app cookie (should NOT be stripped)
req.AddCookie(&http.Cookie{Name: "my_app", Value: "val"})
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
cookieNames := make(map[string]bool)
for _, c := range capturedCookies {
cookieNames[c.Name] = true
}
// Custom prefix cookies should be stripped
if cookieNames[customPrefix+"m"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"m")
}
if cookieNames[customPrefix+"s_0"] {
t.Errorf("expected cookie %q to be stripped", customPrefix+"s_0")
}
// Default prefix cookie should pass through (different prefix)
if !cookieNames["_oidc_raczylo_m"] {
t.Error("expected _oidc_raczylo_m cookie to pass through (different prefix)")
}
// App cookie should pass through
if !cookieNames["my_app"] {
t.Error("expected my_app cookie to pass through")
}
}
+113 -63
View File
@@ -16,6 +16,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -208,6 +209,32 @@ func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *
return m.JWKS, m.Err
}
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.Err != nil {
return nil, m.Err
}
if m.JWKS == nil {
return nil, fmt.Errorf("JWKS is nil")
}
for i := range m.JWKS.Keys {
k := &m.JWKS.Keys[i]
if k.Kid != kid {
continue
}
switch k.Kty {
case "RSA":
return k.ToRSAPublicKey()
case "EC":
return k.ToECDSAPublicKey()
default:
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
}
}
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
}
func (m *MockJWKCache) Cleanup() {
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
// Real cleanup is for expired entries, not resetting all data
@@ -554,7 +581,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case to avoid replay issues
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -577,7 +604,7 @@ func TestServeHTTP(t *testing.T) {
// even if session.SetAuthenticated(true) was called.
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -634,7 +661,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -652,7 +679,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -680,7 +707,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true) // Set flag initially
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
// Create an expired token for this test
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
@@ -715,7 +742,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(nearExpiryToken)
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
},
@@ -746,7 +773,7 @@ func TestServeHTTP(t *testing.T) {
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
})
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken(validToken)
session.SetIDToken(validToken) // Ensure ID token is also set
session.SetRefreshToken("should-not-be-used-refresh-token")
@@ -766,7 +793,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -788,7 +815,7 @@ func TestServeHTTP(t *testing.T) {
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@disallowed.com") // Use disallowed domain
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
// Generate a fresh valid token for this test case
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
@@ -1848,14 +1875,14 @@ func TestHandleLogout(t *testing.T) {
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
expectedURL: "/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
expectedURL: "/",
host: "test-host",
},
{
@@ -2153,7 +2180,7 @@ func TestHandleExpiredToken(t *testing.T) {
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
})
session.SetAccessToken(expiredToken)
session.SetEmail("test@example.com")
session.SetUserIdentifier("test@example.com")
},
expectedPath: "/original/path",
},
@@ -2322,19 +2349,22 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Skip("Skipping test in short mode")
}
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider metadata server. Issuer + endpoints must share the
// host with ProviderURL (the httptest server), otherwise the discovery doc
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
RevokeURL: mockServer.URL + "/revoke",
EndSessionURL: mockServer.URL + "/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2347,6 +2377,7 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create multiple middleware instances
@@ -2387,18 +2418,20 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
// Verify each instance has its own unique configuration. Issuer is now
// pinned to the provider host (audit ranks 21/22), so it equals the
// mock server URL rather than a fixed literal.
if m.issuerURL != mockServer.URL {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
}
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
if m.authURL != mockServer.URL+"/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
}
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
if m.tokenURL != mockServer.URL+"/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
}
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
if m.jwksURL != mockServer.URL+"/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
@@ -2412,15 +2445,16 @@ func TestMultipleMiddlewareInstances(t *testing.T) {
m.ServeHTTP(rr, req)
// Should redirect to auth URL since not authenticated
// Should redirect (302) to the auth flow since not authenticated. The
// absolute auth URL is not asserted here: with issuer pinning (audit
// ranks 21/22) the discovery host equals the httptest server host,
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
// to emit a loopback authorization URL in this test environment. The
// per-instance auth/token/jwks/issuer URLs were already verified above;
// here we only confirm each instance independently triggers a redirect.
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
@@ -2433,33 +2467,43 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
}
// Create two mock provider metadata servers simulating different Keycloak realms
realm1Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Issuer + endpoints must share the host with each realm's ProviderURL
// (the httptest server), otherwise the discovery doc is rejected as
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
// path so the per-realm isolation assertions below still hold, but base
// the host on the server URL — which is exactly what a same-host Keycloak
// deployment looks like.
var realm1Server *httptest.Server
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm1Server.URL + "/realms/realm1"
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm1",
AuthURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm1/protocol/openid-connect/logout",
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
defer realm1Server.Close()
realm2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var realm2Server *httptest.Server
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
w.WriteHeader(http.StatusNotFound)
return
}
base := realm2Server.URL + "/realms/realm2"
metadata := ProviderMetadata{
Issuer: "https://keycloak.example.com/realms/realm2",
AuthURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/auth",
TokenURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/token",
JWKSURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/certs",
EndSessionURL: "https://keycloak.example.com/realms/realm2/protocol/openid-connect/logout",
Issuer: base,
AuthURL: base + "/protocol/openid-connect/auth",
TokenURL: base + "/protocol/openid-connect/token",
JWKSURL: base + "/protocol/openid-connect/certs",
EndSessionURL: base + "/protocol/openid-connect/logout",
}
json.NewEncoder(w).Encode(metadata)
}))
@@ -2473,6 +2517,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm1/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm1_",
RateLimit: 100,
}
// Config for realm2
@@ -2483,6 +2528,7 @@ func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
CallbackURL: "/realm2/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
CookiePrefix: "_oidc_realm2_",
RateLimit: 100,
}
// Create middleware instances for both realms
@@ -2581,8 +2627,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
providerAvailable := false
var mu sync.Mutex
// Create mock provider that initially fails, then becomes available
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Create mock provider that initially fails, then becomes available.
// Issuer + endpoints must share the host with ProviderURL (audit ranks
// 21/22), so derive them from the server URL.
var mockServer *httptest.Server
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
available := providerAvailable
mu.Unlock()
@@ -2594,11 +2643,11 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
if r.URL.Path == "/.well-known/openid-configuration" {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
EndSessionURL: "https://test-issuer.com/logout",
Issuer: mockServer.URL,
AuthURL: mockServer.URL + "/auth",
TokenURL: mockServer.URL + "/token",
JWKSURL: mockServer.URL + "/jwks",
EndSessionURL: mockServer.URL + "/logout",
}
json.NewEncoder(w).Encode(metadata)
return
@@ -2613,6 +2662,7 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
RateLimit: 100,
}
// Create middleware while provider is unavailable
@@ -2659,10 +2709,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
providerAvailable = true
mu.Unlock()
// Reset the retry timer to allow immediate retry
m.metadataRetryMutex.Lock()
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
m.metadataRetryMutex.Unlock()
// Reset the retry timer to allow immediate retry. The field is atomic
// now, so no lock is needed.
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
// Second request should trigger recovery attempt
req2 := httptest.NewRequest("GET", "/protected", nil)
@@ -2730,7 +2779,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2756,7 +2805,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2783,7 +2832,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
@@ -2803,7 +2852,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
@@ -2825,7 +2874,7 @@ func TestServeHTTPRolesAndGroups(t *testing.T) {
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
@@ -4526,6 +4575,7 @@ func TestNewWithScopeAppending(t *testing.T) {
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
Scopes: tc.configScopes,
RateLimit: 100,
}
// Create middleware instance
+24 -15
View File
@@ -9,13 +9,18 @@ import (
// LazyBackgroundTask wraps BackgroundTask to provide delayed initialization.
// This prevents memory leaks from unnecessary background tasks by starting
// them only when actually needed, reducing resource usage in idle scenarios.
//
// Lifecycle is one-shot: once Stop has been called the task cannot be
// restarted. The underlying BackgroundTask uses sync.Once for Start and
// refuses to re-run after Stop, so restart is not supported by design.
type LazyBackgroundTask struct {
// BackgroundTask is the underlying task implementation
*BackgroundTask
// started tracks whether the task has been activated
// mu guards the started flag against concurrent StartIfNeeded / Stop calls.
mu sync.Mutex
// started tracks whether the task has been activated.
// Only mutated while holding mu.
started bool
// startOnce ensures single initialization
startOnce sync.Once
}
// NewLazyBackgroundTask creates a background task that doesn't start immediately.
@@ -29,24 +34,28 @@ func NewLazyBackgroundTask(name string, interval time.Duration, taskFunc func(),
}
// StartIfNeeded starts the background task only if it hasn't been started yet.
// Uses sync.Once to ensure thread-safe single initialization.
// Safe to call concurrently. After Stop has been called this is a no-op;
// the task is not restartable.
func (lt *LazyBackgroundTask) StartIfNeeded() {
lt.startOnce.Do(func() {
if !lt.started {
lt.BackgroundTask.Start()
lt.started = true
}
})
lt.mu.Lock()
defer lt.mu.Unlock()
if lt.started {
return
}
lt.BackgroundTask.Start()
lt.started = true
}
// Stop stops the background task if it was started.
// Resets the start state to allow potential future re-initialization.
// Once stopped, the task cannot be restarted (see type doc).
func (lt *LazyBackgroundTask) Stop() {
if lt.started {
lt.BackgroundTask.Stop()
lt.started = false
lt.startOnce = sync.Once{}
lt.mu.Lock()
defer lt.mu.Unlock()
if !lt.started {
return
}
lt.BackgroundTask.Stop()
lt.started = false
}
// NewLazyCacheWithLogger creates a cache that doesn't start cleanup until first use.
+1
View File
@@ -1652,6 +1652,7 @@ func TestGoroutineLeaks(t *testing.T) {
config.SessionEncryptionKey = "test-encryption-key-32-bytes-long"
config.ClientID = "test-client"
config.ClientSecret = "test-secret"
config.CallbackURL = "/callback"
handler, err := New(context.Background(), nil, config, "test")
require.NoError(t, err)
+142 -12
View File
@@ -58,13 +58,21 @@ func (mpl MemoryPressureLevel) String() string {
}
}
// MemoryMonitor provides comprehensive memory monitoring and alerting
// MemoryMonitor provides comprehensive memory monitoring and alerting.
//
// Memory sampling is expensive: runtime.ReadMemStats is a stop-the-world
// operation. To keep latency predictable the monitor caches the most recent
// sample and only refreshes it when the background ticker fires, when TriggerGC
// is invoked, or when a caller explicitly calls Refresh(). GetCurrentStats is a
// cheap read of that cached sample.
type MemoryMonitor struct {
lastGCTime time.Time
startTime time.Time
lastStats *MemoryStats
cachedMemStats runtime.MemStats
logger *Logger
alertThresholds MemoryAlertThresholds
config MemoryMonitorConfig
baselineGoroutines int
baselineHeap uint64
heapGrowthRate float64
@@ -84,6 +92,30 @@ type MemoryAlertThresholds struct {
GCFrequency float64 // Alert when GC frequency exceeds this per minute
}
// MemoryMonitorConfig configures the memory monitor's scheduling behavior.
// Thresholds are kept separate in MemoryAlertThresholds.
type MemoryMonitorConfig struct {
// Interval between background samples. Must be >= MinMemoryMonitorInterval
// (30s). Values below the minimum are clamped when monitoring starts.
Interval time.Duration
}
// Default and minimum interval values. The minimum exists because
// runtime.ReadMemStats is stop-the-world and hammering it on a hot loop causes
// noticeable latency spikes, especially under yaegi.
const (
DefaultMemoryMonitorInterval = 60 * time.Second
MinMemoryMonitorInterval = 30 * time.Second
)
// DefaultMemoryMonitorConfig returns a config with sensible production
// defaults.
func DefaultMemoryMonitorConfig() MemoryMonitorConfig {
return MemoryMonitorConfig{
Interval: DefaultMemoryMonitorInterval,
}
}
// DefaultMemoryAlertThresholds returns sensible default alert thresholds
func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
return MemoryAlertThresholds{
@@ -95,35 +127,82 @@ func DefaultMemoryAlertThresholds() MemoryAlertThresholds {
}
}
// NewMemoryMonitor creates a new memory monitor
// NewMemoryMonitor creates a new memory monitor using default scheduling
// configuration. See NewMemoryMonitorWithConfig for full control.
func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryMonitor {
return NewMemoryMonitorWithConfig(logger, thresholds, DefaultMemoryMonitorConfig())
}
// NewMemoryMonitorWithConfig creates a new memory monitor with an explicit
// scheduling config.
//
// NOTE: the constructor performs a single runtime.ReadMemStats call to capture
// baseline heap / goroutine / GC counters used for leak and growth detection.
// This is a one-time stop-the-world cost at startup; all subsequent samples
// only happen on the monitoring ticker or on explicit Refresh() calls.
func NewMemoryMonitorWithConfig(logger *Logger, thresholds MemoryAlertThresholds, config MemoryMonitorConfig) *MemoryMonitor {
if logger == nil {
logger = GetSingletonNoOpLogger()
}
if config.Interval <= 0 {
config.Interval = DefaultMemoryMonitorInterval
}
// One-time initial sample to seed baselines used for growth / leak
// detection. All subsequent sampling is gated by the monitoring ticker or
// explicit Refresh() calls.
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
return &MemoryMonitor{
mm := &MemoryMonitor{
logger: logger,
startTime: time.Now(),
alertThresholds: thresholds,
config: config,
baselineHeap: memStats.HeapAlloc,
baselineGoroutines: runtime.NumGoroutine(),
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
lastGCCount: memStats.NumGC,
}
mm.cachedMemStats = memStats
return mm
}
// GetCurrentStats collects current memory statistics
// GetCurrentStats returns the most recently sampled memory statistics.
//
// This is a cheap cached read: it does NOT call runtime.ReadMemStats. Samples
// are refreshed only by the monitoring ticker or by an explicit call to
// Refresh(). If no sample has been produced yet, stats derived from the
// constructor-time raw sample are returned (with no additional STW cost).
func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
mm.mu.RLock()
stats := mm.lastStats
mm.mu.RUnlock()
if stats != nil {
return stats
}
return mm.buildStatsFromCache()
}
// Refresh synchronously samples current memory statistics via
// runtime.ReadMemStats and updates the cached value. This is the only path
// (other than the monitoring ticker and TriggerGC) that pays the stop-the-world
// cost. Use it in tests or in callers that explicitly need a fresh sample.
func (mm *MemoryMonitor) Refresh() *MemoryStats {
return mm.sample()
}
// sample performs a stop-the-world ReadMemStats, updates the cached raw stats,
// computes a derived MemoryStats snapshot, and stores it as lastStats.
func (mm *MemoryMonitor) sample() *MemoryStats {
var memStats runtime.MemStats
runtime.ReadMemStats(&memStats)
now := time.Now()
// Calculate GC frequency
// Calculate GC frequency relative to the previous snapshot.
gcFrequency := 0.0
mm.mu.RLock()
lastStats := mm.lastStats
@@ -168,6 +247,7 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
mm.updateHeapGrowthTracking(stats)
mm.mu.Lock()
mm.cachedMemStats = memStats
mm.lastStats = stats
mm.lastGCCount = memStats.NumGC
mm.mu.Unlock()
@@ -175,6 +255,35 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
return stats
}
// buildStatsFromCache constructs a MemoryStats snapshot from the cached raw
// runtime.MemStats without issuing a new ReadMemStats call. Used as a fallback
// when GetCurrentStats is called before the first sample() has completed.
func (mm *MemoryMonitor) buildStatsFromCache() *MemoryStats {
mm.mu.RLock()
memStats := mm.cachedMemStats
mm.mu.RUnlock()
stats := &MemoryStats{
HeapAllocBytes: memStats.HeapAlloc,
HeapSysBytes: memStats.HeapSys,
HeapIdleBytes: memStats.HeapIdle,
HeapInuseBytes: memStats.HeapInuse,
HeapReleasedBytes: memStats.HeapReleased,
HeapObjects: memStats.HeapObjects,
StackInuseBytes: memStats.StackInuse,
StackSysBytes: memStats.StackSys,
GCSysBytes: memStats.GCSys,
NumGoroutines: runtime.NumGoroutine(),
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
GCFrequency: 0.0,
Timestamp: time.Now(),
}
mm.collectApplicationStats(stats)
stats.MemoryPressure = mm.calculateMemoryPressure(stats)
return stats
}
// collectApplicationStats gathers application-specific memory stats
func (mm *MemoryMonitor) collectApplicationStats(stats *MemoryStats) {
// Get session count from ChunkManager if available
@@ -229,7 +338,7 @@ func (mm *MemoryMonitor) updateGoroutineTracking(stats *MemoryStats) {
}
// Check for potential goroutine leak
if stats.NumGoroutines > mm.baselineGoroutines+int(mm.alertThresholds.GoroutineCount) {
if stats.NumGoroutines > mm.baselineGoroutines+mm.alertThresholds.GoroutineCount {
mm.mu.Lock()
wasAlert := mm.goroutineLeakAlert
if !wasAlert {
@@ -302,7 +411,16 @@ var (
globalMonitoringMutex sync.Mutex
)
// StartMonitoring starts continuous memory monitoring as a global singleton
// StartMonitoring starts continuous memory monitoring as a global singleton.
//
// The effective interval is resolved as follows:
// 1. If the caller passes a positive interval, that is used.
// 2. Otherwise the configured MemoryMonitorConfig.Interval is used.
// 3. Otherwise the built-in default (60s) is used.
//
// The result is then clamped to a minimum of MinMemoryMonitorInterval (30s) to
// avoid stop-the-world ReadMemStats storms. Callers that need rapid updates in
// tests should call Refresh() directly instead of spinning the ticker fast.
func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Duration) {
globalMonitoringMutex.Lock()
defer globalMonitoringMutex.Unlock()
@@ -316,7 +434,17 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
}
if interval <= 0 {
interval = 30 * time.Second
interval = mm.config.Interval
}
if interval <= 0 {
interval = DefaultMemoryMonitorInterval
}
if interval < MinMemoryMonitorInterval {
if !isTestMode() {
mm.logger.Debug("Memory monitor interval %v is below minimum %v; clamping",
interval, MinMemoryMonitorInterval)
}
interval = MinMemoryMonitorInterval
}
registry := GetGlobalTaskRegistry()
@@ -325,7 +453,7 @@ func (mm *MemoryMonitor) StartMonitoring(ctx context.Context, interval time.Dura
"memory-monitor",
interval,
func() {
stats := mm.GetCurrentStats()
stats := mm.sample()
mm.LogMemoryStats(stats)
mm.checkAlerts(stats)
},
@@ -369,14 +497,16 @@ func (mm *MemoryMonitor) checkAlerts(stats *MemoryStats) {
}
}
// TriggerGC forces garbage collection and logs the impact
// TriggerGC forces garbage collection and logs the impact. Both the before and
// after measurements are fresh samples (explicit Refresh() calls) because the
// comparison is meaningless against a stale cached snapshot.
func (mm *MemoryMonitor) TriggerGC() {
before := mm.GetCurrentStats()
before := mm.Refresh()
runtime.GC()
runtime.GC() // Run twice to ensure full collection
after := mm.GetCurrentStats()
after := mm.Refresh()
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
+11 -1
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
@@ -141,10 +142,19 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
}
var metadata ProviderMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
if err := json.NewDecoder(io.LimitReader(resp.Body, 1<<20)).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode metadata: %w", err)
}
// Pin the advertised issuer to the configured provider host. The issuer is
// the trust anchor for JWT issuer validation; rejecting a mismatch here
// ensures a poisoned discovery document advertising an attacker-chosen
// issuer is never cached or returned. Real providers (Google, Azure,
// Keycloak, Okta, Auth0) keep the issuer on the same host as providerURL.
if metadata.Issuer != "" && !sameHost(metadata.Issuer, providerURL) {
return nil, fmt.Errorf("discovery issuer %q host does not match provider %q", metadata.Issuer, providerURL)
}
// Cache for 1 hour by default
if err := mc.Set(providerURL, &metadata, 1*time.Hour); err != nil {
mc.logger.Errorf("Failed to cache metadata: %v", err)
+596 -111
View File
@@ -8,11 +8,105 @@ import (
"fmt"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/lukaszraczylo/traefikoidc/internal/utils"
)
// bypassReason describes why a request is being forwarded without OIDC auth.
// It is only used for logging and to decide whether extra side-effects
// (propagating the user header from an existing session) should run.
const (
bypassReasonExcluded = "excluded-url"
bypassReasonSSE = "sse"
bypassReasonWebSocket = "websocket"
)
// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake
// (RFC 6455). The middleware can only see the handshake; once Traefik
// completes the upgrade it forwards frames directly, so we never re-process
// per-frame traffic. We bypass auth on the handshake the same way we do for
// SSE, because browser WebSocket clients cannot follow an OIDC redirect.
func isWebSocketUpgrade(req *http.Request) bool {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
return false
}
for _, token := range strings.Split(req.Header.Get("Connection"), ",") {
if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
return true
}
}
return false
}
// shouldBypassAuth decides whether a request must skip OIDC authentication
// entirely. It returns (true, reason) when either the request path matches a
// configured excluded URL, the Accept header asks for a text/event-stream
// response (SSE), or the request is a WebSocket upgrade handshake. The
// reason lets ServeHTTP apply any side-effects that are unique to the bypass
// kind (e.g. propagating user headers).
//
// This must be called BEFORE waiting on t.initComplete so excluded, SSE and
// WebSocket traffic is never blocked by a slow/broken provider.
func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) {
if t.determineExcludedURL(req.URL.Path) {
return true, bypassReasonExcluded
}
if strings.Contains(req.Header.Get("Accept"), "text/event-stream") {
return true, bypassReasonSSE
}
if isWebSocketUpgrade(req) {
return true, bypassReasonWebSocket
}
return false, ""
}
// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass
// requests and, on success, copies the authenticated user's identity onto
// the outgoing request so downstream services can see who the user is.
//
// Returns true when the request carries a valid authenticated session and
// the bypass should proceed. Returns false when no usable session is
// present; callers must then reject the request (typically with 401) to
// prevent unauthenticated traffic from reaching the backend just by setting
// `Accept: text/event-stream` or sending a WebSocket upgrade.
//
// The check is cookie-only: the session cookie is sealed by our encryption
// key, so the authenticated flag cannot be forged. We do NOT run full token
// signature verification here so that SSE/WS keeps working when the OIDC
// provider is briefly unavailable for JWK fetches.
func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool {
if t.sessionManager == nil {
return false
}
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Debugf("%s bypass: unable to load session: %v", reason, err)
return false
}
defer session.returnToPoolSafely()
if !session.GetAuthenticated() {
t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason)
return false
}
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason)
return false
}
req.Header.Set("X-Forwarded-User", userIdentifier)
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-User", userIdentifier)
}
t.logger.Debugf("%s bypass: forwarded user %s from session", reason, userIdentifier)
return true
}
// ServeHTTP implements the main middleware logic for processing HTTP requests.
// It handles the complete OIDC authentication flow including:
// - Excluded URL bypass
@@ -26,38 +120,120 @@ import (
// - rw: The HTTP response writer.
// - req: The incoming HTTP request.
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Log request entry for debugging routing issues
t.logger.Debugf("Incoming request: %s %s", req.Method, req.URL.Path)
// Handle logout requests early - before waiting for OIDC initialization
// This allows users to logout even if the OIDC provider is unavailable
if req.URL.Path == t.logoutURLPath {
t.logger.Debugf("Logout path matched early: %s", req.URL.Path)
t.handleLogout(rw, req)
return
}
// Handle backchannel logout (IdP-initiated POST with logout_token)
if t.enableBackchannelLogout && t.backchannelLogoutPath != "" && req.URL.Path == t.backchannelLogoutPath {
t.logger.Debug("Backchannel logout path matched")
t.handleBackchannelLogout(rw, req)
return
}
// Handle front-channel logout (IdP-initiated GET with sid/iss in iframe)
if t.enableFrontchannelLogout && t.frontchannelLogoutPath != "" && req.URL.Path == t.frontchannelLogoutPath {
t.logger.Debug("Front-channel logout path matched")
t.handleFrontchannelLogout(rw, req)
return
}
if !strings.HasPrefix(req.URL.Path, "/health") {
t.firstRequestMutex.Lock()
if !t.firstRequestReceived {
t.firstRequestReceived = true
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
// fired on EVERY non-health request forever (even after the boolean
// flipped true), which under Yaegi added a per-request serialization
// point. CAS gives single-firing semantics with zero steady-state cost.
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
t.logger.Debug("Starting background tasks on first request")
t.startTokenCleanup()
if !t.metadataRefreshStarted && t.providerURL != "" {
t.metadataRefreshStarted = true
if t.providerURL != "" &&
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
// Metadata refresh is handled by singleton resource manager
t.startMetadataRefresh(t.providerURL)
}
}
t.firstRequestMutex.Unlock()
}
// Evaluate auth-bypass once, before waiting for initialization. Excluded
// URLs, SSE and WebSocket upgrade requests must not block on provider
// init. For SSE/WebSocket we ALSO require an authenticated session
// (cookie-only check, no JWK fetch) and otherwise return 401 — clients
// of in-flight streams can't follow an OIDC redirect, so forwarding
// unauthenticated traffic would silently expose the backend.
if bypass, reason := t.shouldBypassAuth(req); bypass {
t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason)
// When bearer auth is enabled, strip the Authorization header on
// bypassed paths so a bearer token can't leak into health/metrics/
// public endpoint logs via downstream services that don't expect it.
// Excluded URLs are explicitly public; bearer is an artifact of the
// API auth flow that doesn't belong on them.
if t.enableBearerAuth {
req.Header.Del("Authorization")
}
switch reason {
case bypassReasonExcluded:
// Operator-declared excluded URLs forward unconditionally.
t.next.ServeHTTP(rw, req)
case bypassReasonSSE, bypassReasonWebSocket:
// Skip the OIDC redirect dance (clients can't follow it
// mid-stream) but still require an authenticated session.
// Otherwise an unauthenticated client could hit the backend
// just by setting Accept: text/event-stream or sending a
// WebSocket upgrade.
if !t.applyBypassUserHeaders(req, reason) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
t.next.ServeHTTP(rw, req)
default:
t.next.ServeHTTP(rw, req)
}
return
}
// Log waiting for initialization to help diagnose hanging requests
t.logger.Debug("Waiting for OIDC provider initialization...")
// time.NewTimer + Stop avoids leaking a goroutine+channel for 30s on every
// request when initComplete fires quickly (would happen with time.After).
initTimer := time.NewTimer(30 * time.Second)
defer initTimer.Stop()
select {
case <-t.initComplete:
// Read issuerURL with RLock
t.metadataMu.RLock()
issuerURL := t.issuerURL
t.metadataMu.RUnlock()
// Read issuerURL via atomic snapshot when available — replaces the
// metadataMu.RLock that previously fired on every non-bypass request.
// Under Yaegi each RLock acquisition costs 1-5ms of interpreter
// dispatch; the snapshot is a single atomic.Value.Load. Falls back
// to the legacy field+RLock for paths that haven't published a
// snapshot yet (notably some test setups that initialize the struct
// fields directly).
var issuerURL string
if snap := t.metadataSnap(); snap != nil {
issuerURL = snap.IssuerURL
} else {
t.metadataMu.RLock()
issuerURL = t.issuerURL
t.metadataMu.RUnlock()
}
if issuerURL == "" {
// Provider metadata initialization failed - try to recover
// Retry every 30 seconds to allow automatic recovery when provider comes back online
t.metadataRetryMutex.Lock()
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
if shouldRetry {
t.lastMetadataRetryTime = time.Now()
}
t.metadataRetryMutex.Unlock()
// Provider metadata initialization failed - try to recover.
// Retry every 30 seconds to allow automatic recovery. Lock-free
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
// the window, others see shouldRetry=false.
nowNano := time.Now().UnixNano()
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
if shouldRetry && t.providerURL != "" {
t.logger.Info("Attempting to recover OIDC provider metadata...")
@@ -72,26 +248,33 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.logger.Debug("Request canceled while waiting for OIDC initialization")
t.sendErrorResponse(rw, req, "Request canceled", http.StatusRequestTimeout)
return
case <-time.After(30 * time.Second):
case <-initTimer.C:
t.logger.Error("Timeout waiting for OIDC initialization")
t.sendErrorResponse(rw, req, "Timeout waiting for OIDC provider initialization - please try again later", http.StatusServiceUnavailable)
return
}
if t.determineExcludedURL(req.URL.Path) {
t.logger.Debugf("Request path %s excluded by configuration, bypassing OIDC", req.URL.Path)
t.next.ServeHTTP(rw, req)
return
}
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "text/event-stream") {
t.logger.Debugf("Request accepts text/event-stream (%s), bypassing OIDC", acceptHeader)
t.next.ServeHTTP(rw, req)
return
}
// Bypass checks already ran before the init wait; no need to repeat them.
t.sessionManager.CleanupOldCookies(rw, req)
// Bearer-token auth (opt-in). Runs after init (we need issuer+JWKs+aud
// available) and after bypass (excluded URLs always win). Cookie-vs-
// bearer precedence is configurable; the safe default is cookie-wins.
// See bearer_auth.go for the full pipeline.
if t.enableBearerAuth {
if _, hasBearer := detectBearerToken(req); hasBearer {
cookiePresent := t.hasSessionCookie(req)
if !cookiePresent || t.bearerOverridesCookie {
if cookiePresent {
t.logger.Infof("Both Authorization: Bearer and session cookie present on %s; bearer-wins per BearerOverridesCookie=true", req.URL.Path)
}
t.handleBearerRequest(rw, req)
return
}
t.logger.Infof("Both Authorization: Bearer and session cookie present on %s; cookie-wins (default); bearer ignored", req.URL.Path)
}
}
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
@@ -107,6 +290,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.sendErrorResponse(rw, req, "Critical session error", http.StatusInternalServerError)
return
}
// Sub-resource requests (script/image/fetch/serviceWorker) must not
// trigger an OIDC redirect from this path either: they would overwrite
// any in-flight CSRF/nonce in the session. Let the next HTML navigation
// initiate the flow. See issue #129.
if t.isAjaxRequest(req) || t.isNonNavigationRequest(req) {
t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized)
return
}
scheme := utils.DetermineScheme(req, t.forceHTTPS)
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
@@ -120,16 +311,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
host := utils.DetermineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Capture per-request state: one RLock on sd.sessionMutex covers all the
// getter values the handler chain needs (instead of 5-7 separate
// session.GetX() calls each acquiring their own RLock under Yaegi).
// metadataSnap is also stored once so downstream handlers don't repeat
// the atomic.Value.Load.
rs := (&requestState{
scheme: scheme,
host: host,
redirectURL: redirectURL,
next: t.next,
metadata: t.metadataSnap(),
}).captureSession(session)
// Check if the current request is the OIDC callback
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
if req.URL.Path == t.redirURLPath {
t.logger.Debugf("Callback URL matched, processing OIDC callback (redirect_url=%s)", redirectURL)
t.handleCallback(rw, req, redirectURL)
return
}
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
// Token validation reads session via the captured snapshot — saves ~21
// sd.sessionMutex.RLock acquisitions (Yaegi-dispatched, ~1-5ms each)
// across the validation path.
authenticated, needsRefresh, expired := t.isUserAuthenticatedRS(rs)
if expired {
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
@@ -137,7 +344,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
userIdentifier := rs.userIdentifier
// User authorization check
if authenticated && userIdentifier != "" {
if !t.isAllowedUser(userIdentifier) {
@@ -154,14 +361,18 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// methods (validateAzureTokens/validateStandardTokens) before reaching this point.
// Redundant validation here was causing issues with Azure AD tokens that have
// JWT format but unverifiable signatures. See issue #89.
t.processAuthorizedRequest(rw, req, session, redirectURL)
t.processAuthorizedRequestRS(rw, req, rs)
return
}
refreshTokenPresent := session.GetRefreshToken() != ""
refreshTokenPresent := rs.refreshToken != ""
// Check if this is an AJAX request that should receive 401 instead of redirect
isAjaxRequest := t.isAjaxRequest(req)
// Decide whether to answer with 401 instead of a redirect. AJAX requests
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
// fetch/serviceWorker) must not trigger a fresh OIDC flow because parallel
// loads would each overwrite the session CSRF/nonce (issue #129). Only
// top-level HTML navigations should redirect.
isAjaxRequest := t.isAjaxRequest(req) || t.isNonNavigationRequest(req)
// Check if refresh token is likely expired (older than 6 hours)
refreshTokenExpired := refreshTokenPresent && t.isRefreshTokenExpired(session)
@@ -205,7 +416,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
refreshed := t.refreshToken(rw, req, session)
if refreshed {
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
userIdentifier = session.GetUserIdentifier()
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
@@ -245,55 +456,293 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// processAuthorizedRequest processes requests for authenticated users.
// It extracts claims, validates roles/groups if configured, sets authentication headers,
// processes header templates, and forwards the request to the next handler.
// Domain checks should be performed before calling this method.
// processAuthorizedRequest processes requests for authenticated cookie/session
// users. It performs session-specific checks (identifier presence, backchannel-
// logout invalidation, claims extraction with potential re-auth), persists
// dirty session state, then delegates the post-auth pipeline (roles/groups,
// header injection, security headers, cookie strip, forward) to
// forwardAuthorized.
//
// The bearer-token path uses the same forwardAuthorized helper but takes a
// different route to it (see bearer_auth.go). Keeping forwardAuthorized
// session-agnostic is what lets the two auth methods share one pipeline.
//
// Parameters:
// - rw: The HTTP response writer.
// - req: The HTTP request to process.
// - session: The user's session data containing tokens and claims.
// - redirectURL: The callback URL for re-authentication if needed.
//
// processAuthorizedRequestRS is the requestState-aware variant of
// processAuthorizedRequest. It reads SessionData fields from the captured
// snapshot in rs instead of calling session.GetX() (each of which acquires
// sd.sessionMutex.RLock — under Yaegi every RLock pays ~1-5ms of interpreter
// dispatch). Only session-mutating operations (Save, ResetRedirectCount,
// Clear, IsDirty) still go through the session pointer because those write
// state and have no snapshot.
func (t *TraefikOidc) processAuthorizedRequestRS(rw http.ResponseWriter, req *http.Request, rs *requestState) {
session := rs.session
redirectURL := rs.redirectURL
userIdentifier := rs.userIdentifier
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Check if session has been invalidated via backchannel or front-channel logout
idToken := rs.idToken
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
}
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
}
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token.
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if rs.accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(rs.accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Persist any dirty session state BEFORE forwardAuthorized writes the
// response.
if session.IsDirty() {
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after processing headers: %v", err)
}
} else {
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
}
p := &principal{
Source: sourceSession,
Identifier: userIdentifier,
AccessToken: rs.accessToken,
IDToken: idToken,
RefreshToken: rs.refreshToken,
Claims: groupClaims,
}
t.forwardAuthorized(rw, req, p)
}
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
email := session.GetEmail()
if email == "" {
t.logger.Info("No email found in session during final processing, initiating re-auth")
userIdentifier := session.GetUserIdentifier()
if userIdentifier == "" {
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
// Reset redirect count to prevent loops when session is invalid
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
tokenForClaims := session.GetIDToken()
if tokenForClaims == "" {
tokenForClaims = session.GetAccessToken()
if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
// Reset redirect count to prevent loops when token is missing
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
// Check if session has been invalidated via backchannel or front-channel logout
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
idToken := session.GetIDToken()
if idToken != "" {
sid, sub, createdAt := t.extractSessionInfo(idToken)
if t.isSessionInvalidated(sid, sub, createdAt) {
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
// Clear the session and redirect to login
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing invalidated session: %v", err)
}
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
}
}
// Initialize empty slices
var groups, roles []string
// Resolve ID-token claims at most once per request. SessionData caches
// the parsed claims keyed on the raw ID token, so concurrent dashboard
// panel requests on the same session don't repeatedly base64-decode and
// JSON-unmarshal the same JWT (a real cost under the yaegi interpreter
// that hosts Traefik plugins).
idToken := session.GetIDToken()
var (
idClaims map[string]interface{}
idClaimsErr error
)
if idToken != "" {
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
}
if tokenForClaims != "" {
var err error
groups, roles, err = t.extractGroupsAndRoles(tokenForClaims)
if err != nil && len(t.allowedRolesAndGroups) > 0 {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
// Reset redirect count to prevent loops when claim extraction fails
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Choose which claims drive groups/roles extraction. Prefer the ID
// token (cached) and fall back to the access token if there is no ID
// token in the session — matching the prior behavior for opaque
// ID-token providers.
var (
groupClaims map[string]interface{}
groupClaimsErr error
)
if idToken != "" {
groupClaims, groupClaimsErr = idClaims, idClaimsErr
} else if accessToken := session.GetAccessToken(); accessToken != "" {
groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken)
} else if len(t.allowedRolesAndGroups) > 0 {
t.logger.Error("No token available but roles/groups checks are required")
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
// Claims couldn't be extracted but roles checks are required:
// re-authenticate rather than 403 (session may be salvageable on
// re-issue). Bearer path uses 401 for the equivalent failure.
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
session.ResetRedirectCount()
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
// Persist any dirty session state BEFORE forwardAuthorized writes the
// response. Once next.ServeHTTP fires, Set-Cookie can no longer reach
// the client. The forwardAuthorized pipeline does not mutate session
// state, so saving here is safe.
if session.IsDirty() {
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after processing headers: %v", err)
}
} else {
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
}
// Build the source-agnostic principal. ID-token claims drive header
// templates and roles when present; otherwise fall back to access-token
// claims (matches prior behavior for opaque-ID-token providers).
p := &principal{
Source: sourceSession,
Identifier: userIdentifier,
AccessToken: session.GetAccessToken(),
IDToken: idToken,
RefreshToken: session.GetRefreshToken(),
Claims: groupClaims,
}
t.forwardAuthorized(rw, req, p)
}
// forwardAuthorized completes the post-authentication pipeline shared by the
// cookie/session path and the bearer-token path. It performs:
//
// 1. Roles/groups extraction from p.Claims (idempotent; existing
// extractGroupsAndRolesFromClaims helper).
// 2. allowedRolesAndGroups gate — writes a 403 and returns if denied.
// 3. Identity-header injection (X-Forwarded-User, X-User-Groups, X-User-Roles,
// plus X-Auth-Request-* when !minimalHeaders).
// 4. Operator-defined header templates.
// 5. Security headers (delegated to t.securityHeadersApplier or fallback).
// 6. OIDC session-cookie strip (stripAuthCookies).
// 7. Authorization header strip on bearer source when stripAuthorizationHeader.
// 8. next.ServeHTTP.
//
// Session persistence is the CALLER's responsibility — it must happen before
// this function so Set-Cookie reaches the response.
// headerTemplateMaxLen bounds the length of a rendered operator-defined header
// template before it is forwarded downstream. Generous enough for an
// "Authorization: Bearer <jwt>" value but small enough to reject obviously
// abusive output. Matches the input-validation default header cap (8KB).
const headerTemplateMaxLen = 8192
// headerClaimMaxLen returns the maximum accepted length for a claim-derived
// header value (principal identifier, group, role). Reuses the operator-
// configured identifier cap (default 256) so a single setting governs both
// auth paths; falls back to 256 when unset.
func (t *TraefikOidc) headerClaimMaxLen() int {
if t.maxIdentifierLength > 0 {
return t.maxIdentifierLength
}
return 256
}
// sanitizeHeaderClaimList drops any group/role value that fails claim
// sanitization (control chars, bidi-override runes, the , ; = delimiters, or an
// over-long value) and returns the surviving values. Failing closed on a bad
// entry prevents header injection and stops an embedded comma from injecting
// extra entries into the comma-joined header. headerName is used only for
// debug logging — the value is never logged.
func (t *TraefikOidc) sanitizeHeaderClaimList(values []string, headerName string) []string {
if len(values) == 0 {
return nil
}
safe := make([]string, 0, len(values))
for _, v := range values {
if clean, ok := sanitizeHeaderClaimValue(v, t.headerClaimMaxLen()); ok {
safe = append(safe, clean)
} else {
t.logger.Debugf("Dropping %s entry: value failed claim sanitization", headerName)
}
}
return safe
}
func (t *TraefikOidc) forwardAuthorized(rw http.ResponseWriter, req *http.Request, p *principal) {
var (
groups, roles []string
extractErr error
)
if p.Claims != nil {
groups, roles, extractErr = t.extractGroupsAndRolesFromClaims(p.Claims)
if extractErr != nil && len(t.allowedRolesAndGroups) > 0 {
// Bearer path: 403 (caller already verified the token; principal
// claims are present but malformed for roles purposes).
// Cookie path can't reach here because processAuthorizedRequest
// catches groupClaimsErr earlier.
t.logger.Errorf("Failed to extract groups and roles: %v", extractErr)
t.sendErrorResponse(rw, req, "Access denied", http.StatusForbidden)
return
} else if err == nil {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if extractErr == nil {
// Sanitize each group/role before it is joined into a comma-
// delimited header. The cookie/session path does not otherwise
// sanitize claim-derived values (the bearer path sanitizes its
// identifier at construction), so a control char would enable
// header injection and an embedded comma would inject extra
// entries into the comma-joined header. Fail closed: drop any
// value that does not pass.
if safeGroups := t.sanitizeHeaderClaimList(groups, "X-User-Groups"); len(safeGroups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(safeGroups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
if safeRoles := t.sanitizeHeaderClaimList(roles, "X-User-Roles"); len(safeRoles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(safeRoles, ","))
}
}
}
@@ -307,60 +756,73 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
t.logger.Infof("User %s does not have any allowed roles or groups", p.Identifier)
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
// Sanitize the principal identifier before injecting it into headers. The
// bearer path already sanitizes its identifier at construction; the
// cookie/session path does not, so a claim carrying control chars, bidi-
// override runes, or , ; = could inject or spoof header content. Fail
// closed: drop the identifier header(s) rather than forward a tainted value.
safeIdentifier, identifierOK := sanitizeHeaderClaimValue(p.Identifier, t.headerClaimMaxLen())
if identifierOK {
req.Header.Set("X-Forwarded-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Forwarded-User header: identifier failed claim sanitization")
}
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
if !t.minimalHeaders {
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
req.Header.Set("X-Auth-Request-User", email)
if idToken := session.GetIDToken(); idToken != "" {
req.Header.Set("X-Auth-Request-Token", idToken)
if identifierOK {
req.Header.Set("X-Auth-Request-User", safeIdentifier)
} else {
t.logger.Debugf("Dropping X-Auth-Request-User header: identifier failed claim sanitization")
}
if p.IDToken != "" {
req.Header.Set("X-Auth-Request-Token", p.IDToken)
}
}
if len(t.headerTemplates) > 0 {
claims, err := t.extractClaimsFunc(session.GetIDToken())
if err != nil {
t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err)
} else {
templateData := map[string]interface{}{
"AccessToken": session.GetAccessToken(),
"IDToken": session.GetIDToken(),
"RefreshToken": session.GetRefreshToken(),
"Claims": claims,
}
for headerName, tmpl := range t.headerTemplates {
var buf bytes.Buffer
if err := tmpl.Execute(&buf, templateData); err != nil {
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
continue
}
headerValue := buf.String()
req.Header.Set(headerName, headerValue)
t.logger.Debugf("Set templated header %s = %s", headerName, headerValue)
}
session.MarkDirty()
t.logger.Debugf("Session marked dirty after templated header processing.")
// p.Claims may be nil (e.g. session without an ID token). Templates
// referencing .Claims.* will simply produce empty values — matches
// the prior behavior. Bearer-source principals always carry access-
// token claims (post-verifyToken).
templateData := map[string]interface{}{
"AccessToken": p.AccessToken,
"IDToken": p.IDToken,
"RefreshToken": p.RefreshToken,
"Claims": p.Claims,
}
}
if session.IsDirty() {
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session after processing headers: %v", err)
for headerName, tmpl := range t.headerTemplates {
var buf bytes.Buffer
if err := tmpl.Execute(&buf, templateData); err != nil {
t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err)
continue
}
headerValue := buf.String()
// Sanitize the rendered output: template inputs are claim-derived
// and attacker-influenceable, so reject control chars (header
// injection), bidi-override runes, the , ; = delimiters, and an
// over-long value. Fail closed by dropping the header rather than
// forwarding a tainted value. Do not log the value (it commonly
// carries the access token); log only name + reason.
if reason := headerValueReason(headerValue, headerTemplateMaxLen); reason != "" {
t.logger.Debugf("Dropping templated header %s: value failed sanitization (%s)", headerName, reason)
continue
}
req.Header.Set(headerName, headerValue)
// Do not log the value: templated headers commonly carry the access
// token (e.g. "Authorization: Bearer {{.AccessToken}}"), and logging
// it — even at debug — leaks credentials into logs.
t.logger.Debugf("Set templated header %s (%d bytes)", headerName, len(headerValue))
}
} else {
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
}
// Apply security headers if configured
@@ -374,7 +836,30 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
rw.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
}
t.logger.Debugf("Request authorized for user %s, forwarding to next handler", email)
// Strip OIDC session cookies before forwarding to the backend to prevent
// HTTP 431 "Request Header Fields Too Large" errors (GitHub issue #122).
if t.stripAuthCookies && t.sessionManager != nil {
prefix := t.sessionManager.GetCookiePrefix()
filtered := make([]*http.Cookie, 0, len(req.Cookies()))
for _, c := range req.Cookies() {
if !strings.HasPrefix(c.Name, prefix) {
filtered = append(filtered, c)
}
}
req.Header.Del("Cookie")
for _, c := range filtered {
req.AddCookie(c)
}
}
// Bearer source: strip the Authorization header to keep the raw token
// out of downstream service logs. Off-by-config for operators who chain
// services that each re-verify the bearer.
if p.Source == sourceBearer && t.stripAuthorizationHeader {
req.Header.Del("Authorization")
}
t.logger.Debugf("Request authorized for user %s (source=%d), forwarding to next handler", p.Identifier, p.Source)
t.next.ServeHTTP(rw, req)
}
+51 -19
View File
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate waiting
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
}
// Create request with canceled context
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: createTestSessionManager(t),
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -95,6 +95,38 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
}
}
// TestLogoutWorksWithoutOIDCInitialization tests that logout works even if OIDC provider is unavailable
// This is critical for allowing users to clear their session when the provider is down
func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
oidc := &TraefikOidc{
logger: NewLogger("debug"),
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
sessionManager: createTestSessionManager(t),
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
logoutURLPath: "/logout",
postLogoutRedirectURI: "/",
forceHTTPS: false,
}
// Note: initComplete is NOT closed, simulating OIDC provider being unavailable
req := httptest.NewRequest("GET", "/logout", nil)
req.Host = "example.com"
rw := httptest.NewRecorder()
oidc.ServeHTTP(rw, req)
// Should redirect to post-logout URI even without OIDC initialization
if rw.Code != http.StatusFound {
t.Errorf("Expected redirect (302) for logout, got %d", rw.Code)
}
location := rw.Header().Get("Location")
if location == "" {
t.Error("Expected Location header for logout redirect")
}
}
// TestMiddlewareDomainRestrictions tests domain-based access control
// NOTE: Currently commented out due to complex session setup requirements
// These scenarios are tested indirectly through integration tests
@@ -110,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -129,7 +161,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create authenticated session
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAuthenticated(true)
session.SetIDToken("dummy-token")
session.Save(req, httptest.NewRecorder())
@@ -155,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -171,7 +203,7 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
// Create session with forbidden domain
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@forbidden.com")
session.SetUserIdentifier("user@forbidden.com")
session.SetAuthenticated(true)
// Save and inject cookies
@@ -204,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
logger: NewLogger("debug"),
initComplete: make(chan struct{}),
sessionManager: sessionManager,
firstRequestReceived: true,
metadataRefreshStarted: true,
firstRequestStarted: 1,
metadataRefreshStartedAtomic: 1,
issuerURL: "https://provider.example.com",
redirURLPath: "/callback",
logoutURLPath: "/logout",
@@ -220,7 +252,7 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
// Create session with opaque token
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
session.SetAuthenticated(true)
@@ -259,7 +291,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("") // No email
session.SetUserIdentifier("") // No email
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -289,7 +321,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("") // No ID token
session.SetAccessToken("") // No access token
@@ -317,7 +349,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
session.SetEmail("user@example.com")
session.SetUserIdentifier("user@example.com")
session.SetIDToken("dummy-token")
rw := httptest.NewRecorder()
@@ -351,7 +383,7 @@ func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
session, _ := sessionManager.GetSession(req)
testEmail := "user@example.com"
session.SetEmail(testEmail)
session.SetUserIdentifier(testEmail)
session.SetIDToken("dummy-id-token")
rw := httptest.NewRecorder()
+58
View File
@@ -0,0 +1,58 @@
// Package traefikoidc — principal abstraction for the shared post-auth
// pipeline. A principal carries the resolved identity + tokens + claims
// produced by EITHER the cookie session path or the bearer-token path, so
// downstream header injection / roles checks / forwarding can be implemented
// once and reused.
package traefikoidc
// principalSource indicates which auth path produced a principal. Used by
// forwardAuthorized to decide source-specific behavior (e.g. only strip the
// Authorization header for bearer-source principals).
type principalSource int
const (
sourceSession principalSource = iota
sourceBearer
)
// principal is the immutable post-auth value passed to forwardAuthorized.
// No methods mutate it; no manager pointer; no I/O. Pure data.
type principal struct {
Claims map[string]interface{}
Identifier string
Subject string
ClientID string
AccessToken string
IDToken string
RefreshToken string
Source principalSource
}
// buildPrincipalFromSession adapts an authenticated SessionData into a
// principal value WITHOUT writing back to the session. This is the only
// function that still knows about SessionData; the rest of the pipeline is
// session-agnostic. Returns nil when the session has no usable identity.
func (t *TraefikOidc) buildPrincipalFromSession(session *SessionData) *principal {
if session == nil {
return nil
}
identifier := session.GetUserIdentifier()
if identifier == "" {
return nil
}
var claims map[string]interface{}
if idToken := session.GetIDToken(); idToken != "" && t.extractClaimsFunc != nil {
// Best-effort: cached on the session, never blocking.
claims, _ = session.GetIDTokenClaims(t.extractClaimsFunc) // Safe to ignore: claims-error path handled by header-template branch
}
return &principal{
Source: sourceSession,
Identifier: identifier,
AccessToken: session.GetAccessToken(),
IDToken: session.GetIDToken(),
RefreshToken: session.GetRefreshToken(),
Claims: claims,
}
}
+360 -230
View File
@@ -15,19 +15,28 @@ import (
// It implements request coalescing, rate limiting, and circuit breaking
// specifically for token refresh operations.
type RefreshCoordinator struct {
inFlightRefreshes map[string]*refreshOperation
cleanupTimers map[string]*time.Timer
sessionRefreshAttempts map[string]*refreshAttemptTracker
delayedCleanupQueue chan delayedCleanupItem
// inFlightRefreshes maps tokenHash -> *refreshOperation. sync.Map is used
// instead of a plain map + RWMutex so concurrent refreshes do not
// serialize on a single global lock. Under Yaegi the previous
// refreshMutex.Lock() was held for tens of milliseconds per request due
// to interpreter overhead on the work inside the critical section,
// causing dozens of goroutines to stack up on it and pin one CPU core.
inFlightRefreshes sync.Map
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
// adds 10-50ms of dispatch overhead, and they were called twice per leader
// request (once for recordRefreshAttempt, once for isInCooldown). That
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
// removed the refreshMutex (same architectural shape, different mutex).
sessionRefreshAttempts sync.Map
circuitBreaker *RefreshCircuitBreaker
metrics *RefreshMetrics
logger *Logger
stopChan chan struct{}
config RefreshCoordinatorConfig
wg sync.WaitGroup
attemptsMutex sync.RWMutex
refreshMutex sync.RWMutex
cleanupTimerMu sync.Mutex
}
// RefreshCoordinatorConfig configures the refresh coordinator behavior
@@ -85,14 +94,46 @@ type refreshResult struct {
fromCache bool
}
// refreshAttemptTracker tracks refresh attempts for a session
type refreshAttemptTracker struct {
lastAttemptTime time.Time
windowStartTime time.Time
cooldownEndTime time.Time
// attemptState is the immutable snapshot of a session's refresh-attempt
// state. Lives behind refreshAttemptTracker.state (atomic.Value). Every
// transition (record, success, failure, window-reset, cooldown-enter,
// cooldown-exit) constructs a fresh attemptState and publishes it via
// CompareAndSwap so the entire field set is updated together.
//
// Per-field atomic.Load/Store (the previous v1.0.15 design) had a benign
// but observable hazard: the cooldown-exit reset wrote cooldownEndNano = 0
// first, then separately stored attempts = 1 and windowStartNano = now.
// A concurrent isInCooldown call could see cooldownEndNano = 0 (reset
// just completed) with attempts still at MaxRefreshAttempts, triggering
// a fresh cooldown immediately. The snapshot approach eliminates the
// intermediate state entirely.
type attemptState struct {
lastAttemptNano int64 // UnixNano of last attempt
windowStartNano int64 // UnixNano of attempt-window start
cooldownEndNano int64 // UnixNano; 0 = not in cooldown
attempts int32
consecutiveFailures int32
inCooldown bool
}
// refreshAttemptTracker tracks refresh attempts for a session via a single
// atomic.Value holding a *attemptState pointer. Readers do exactly one Load.
// Writers do Load → construct new → CompareAndSwap (retry on conflict).
// Under Yaegi this collapses 3-4 per-field atomic dispatches into one Load,
// and eliminates the cross-field race in the window-reset path.
type refreshAttemptTracker struct {
state atomic.Value // *attemptState
}
// stateOf returns the current attemptState, or a zero-value snapshot if none
// has been published yet. The empty snapshot represents "no attempts recorded".
func (t *refreshAttemptTracker) stateOf() *attemptState {
if v := t.state.Load(); v != nil {
s, _ := v.(*attemptState)
if s != nil {
return s
}
}
return &attemptState{}
}
// RefreshMetrics tracks coordinator performance metrics
@@ -107,20 +148,18 @@ type RefreshMetrics struct {
currentInFlightRefreshes int32
}
// delayedCleanupItem represents an item scheduled for delayed cleanup
type delayedCleanupItem struct {
cleanupAt time.Time
tokenHash string
}
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
// interpreter dispatch per call, which compounded with attemptsMutex to keep
// the pod's single CPU core saturated.
type RefreshCircuitBreaker struct {
lastFailureTime time.Time
lastSuccessTime time.Time
lastFailureNano int64 // atomic, UnixNano of most recent failure
lastSuccessNano int64 // atomic, UnixNano of most recent success
config RefreshCircuitBreakerConfig
mutex sync.RWMutex
state int32
failures int32
state int32 // atomic: 0=closed, 1=open, 2=half-open
failures int32 // atomic
}
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
@@ -137,14 +176,12 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
}
rc := &RefreshCoordinator{
inFlightRefreshes: make(map[string]*refreshOperation),
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
config: config,
metrics: &RefreshMetrics{},
logger: logger,
stopChan: make(chan struct{}),
delayedCleanupQueue: make(chan delayedCleanupItem, 1000), // Buffered channel for cleanup items
cleanupTimers: make(map[string]*time.Timer),
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
// their zero values are ready to use.
config: config,
metrics: &RefreshMetrics{},
logger: logger,
stopChan: make(chan struct{}),
circuitBreaker: &RefreshCircuitBreaker{
config: RefreshCircuitBreakerConfig{
MaxFailures: 3,
@@ -158,10 +195,6 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
rc.wg.Add(1)
go rc.cleanupRoutine()
// Start delayed cleanup processor (single goroutine processes all cleanup timers)
rc.wg.Add(1)
go rc.processDelayedCleanups()
return rc
}
@@ -234,18 +267,33 @@ func (rc *RefreshCoordinator) CoordinateRefresh(
// Returns (operation, false, nil) if joined an existing operation
// Returns (nil, false, error) if the operation was rejected
func (rc *RefreshCoordinator) getOrCreateOperation(
ctx context.Context,
_ context.Context,
sessionID string,
tokenHash string,
refreshToken string,
) (*refreshOperation, bool, error) {
rc.refreshMutex.Lock()
defer rc.refreshMutex.Unlock()
// Speculatively construct the operation we WOULD register if we win the
// race. Allocating here keeps the LoadOrStore call below atomic and
// avoids any global lock — under Yaegi the previous map+RWMutex design
// held the write lock long enough (tens of ms per call) that concurrent
// refreshes on the same coordinator serialized into a queue that grew
// without bound. See struct comment on inFlightRefreshes.
candidate := &refreshOperation{
refreshToken: refreshToken,
done: make(chan struct{}),
startTime: time.Now(),
waiterCount: 1,
}
// Check for existing operation while holding the lock
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
if existing, loaded := rc.inFlightRefreshes.LoadOrStore(tokenHash, candidate); loaded {
existingOp, ok := existing.(*refreshOperation)
if !ok {
// Defensive: anything stored here is always *refreshOperation, but
// keep the typed assert so a programming error elsewhere doesn't
// surface as a confusing panic in an interpreter frame.
return nil, false, fmt.Errorf("inFlightRefreshes corrupt: unexpected type %T", existing)
}
if existingOp.refreshToken == refreshToken {
// Join existing operation
atomic.AddInt32(&existingOp.waiterCount, 1)
return existingOp, false, nil
}
@@ -253,47 +301,77 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
return nil, false, fmt.Errorf("refresh token mismatch")
}
// No existing operation - check if we can create a new one
// All checks happen while holding the lock to prevent races
// We won the race and registered `candidate`. Apply gates now. If any
// gate fails we must remove our entry from the map and signal failure
// to any joiners that snuck in between LoadOrStore and now.
if err := rc.applyLeaderGates(sessionID); err != nil {
rc.failCandidate(tokenHash, candidate, err)
return nil, false, err
}
// Check and record refresh attempt for rate limiting
rc.recordRefreshAttempt(sessionID)
// Reserve concurrent slot via ticket-and-return: increment optimistically,
// decrement if we overshot the limit. The previous CAS-loop allowed a
// transient overshoot of up to N-1 leaders when several goroutines all
// observed `current < max` in the same scheduling slice before any one
// of them succeeded their CAS — visible to readers as
// currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window.
// The ticket pattern is strictly bounded: the counter momentarily reads
// max+k for k concurrent attempts past the limit, but only the k that
// produced max+1..max+k decrement back, and only k=1 ever observes max+1
// as committed.
newCount := atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
if int(newCount) > rc.config.MaxConcurrentRefreshes {
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
err := fmt.Errorf("maximum concurrent refresh operations reached")
rc.failCandidate(tokenHash, candidate, err)
return nil, false, err
}
return candidate, true, nil
}
// applyLeaderGates runs the rate-limit, cooldown, and memory-pressure checks
// that previously ran under the global refreshMutex. Only the leader (the
// goroutine that just registered the operation) runs them; joiners share the
// leader's outcome via operation.done.
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
// Cooldown check FIRST, BEFORE incrementing the attempt counter.
// Previously this function recorded the attempt and then read the
// cooldown state. Under burst load (many concurrent leaders with
// different token hashes but same session) every goroutine could
// increment past MaxRefreshAttempts before any one of them observed
// the threshold, so the cooldown gate fired too late — the same
// thundering-herd shape that drove v1.0.14 into the ground.
if rc.isInCooldown(sessionID) {
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
}
// Check memory pressure
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
return fmt.Errorf("system under memory pressure, refresh denied")
}
// Only count attempts that actually progress past the gates.
rc.recordRefreshAttempt(sessionID)
return nil
}
// Check and reserve concurrent refresh slot atomically
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
if int(current) >= rc.config.MaxConcurrentRefreshes {
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
}
// Reserve the slot - we're still holding the lock so this is safe
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
// Create and register new operation
operation := &refreshOperation{
refreshToken: refreshToken,
done: make(chan struct{}),
startTime: time.Now(),
waiterCount: 1,
}
rc.inFlightRefreshes[tokenHash] = operation
return operation, true, nil
// failCandidate removes the leader's just-registered operation from the
// in-flight map and signals the error to any joiners by recording the result
// and closing the done channel. This keeps the (nil, false, err) return path
// equivalent to the pre-sync.Map version: callers see the error directly,
// joiners see it via operation.done.
func (rc *RefreshCoordinator) failCandidate(tokenHash string, op *refreshOperation, err error) {
rc.inFlightRefreshes.Delete(tokenHash)
op.mutex.Lock()
op.result = &refreshResult{err: err}
op.mutex.Unlock()
close(op.done)
}
// executeRefreshAsync performs the actual refresh operation asynchronously
func (rc *RefreshCoordinator) executeRefreshAsync(
operation *refreshOperation,
sessionID string,
_ string, // sessionID - reserved for future metrics/logging
tokenHash string,
refreshFunc func() (*TokenResponse, error),
) {
@@ -350,159 +428,227 @@ func (rc *RefreshCoordinator) executeRefreshAsync(
}
}
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine
// This prevents goroutine explosion under high load (500+ req/sec)
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning
// a goroutine — time.AfterFunc uses the runtime's timer heap and never spawns
// a per-timer goroutine until the callback actually fires.
//
// The previous implementation tracked every pending timer in a map guarded by
// cleanupTimerMu so a duplicate scheduling could cancel the prior timer. That
// "shouldn't happen" path was the only consumer of the map, but the mutex
// fired on every successful refresh completion — yet another per-request
// Yaegi-dispatched lock acquisition. performCleanup is already idempotent
// (LoadAndDelete on the sync.Map), so a duplicate scheduling at worst fires
// performCleanup twice; the second call is a no-op. Dropping the map removes
// the whole class of contention on this code path.
func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
delay := rc.config.DeduplicationCleanupDelay
if delay <= 0 {
// Immediate cleanup
rc.performCleanup(tokenHash)
return
}
// Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep
// time.AfterFunc uses the runtime's timer heap which is much more efficient
rc.cleanupTimerMu.Lock()
// Cancel any existing timer for this hash (shouldn't happen, but just in case)
if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists {
existingTimer.Stop()
}
rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() {
rc.performCleanup(tokenHash)
// Remove timer from map
rc.cleanupTimerMu.Lock()
delete(rc.cleanupTimers, tokenHash)
rc.cleanupTimerMu.Unlock()
})
rc.cleanupTimerMu.Unlock()
time.AfterFunc(delay, func() { rc.performCleanup(tokenHash) })
}
// performCleanup removes the operation from the in-flight map
// performCleanup removes the operation from the in-flight map.
// Idempotent: only decrements the in-flight counter if an entry was actually
// removed. LoadAndDelete is atomic so any concurrent failCandidate or repeat
// cleanup call will see exactly one removal — the budget cannot be corrupted
// by double-decrement.
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
rc.refreshMutex.Lock()
delete(rc.inFlightRefreshes, tokenHash)
rc.refreshMutex.Unlock()
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
if _, existed := rc.inFlightRefreshes.LoadAndDelete(tokenHash); existed {
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
}
}
// processDelayedCleanups processes delayed cleanup requests from the queue
// This is a single goroutine that handles all delayed cleanups
func (rc *RefreshCoordinator) processDelayedCleanups() {
defer rc.wg.Done()
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
// concurrent first-touch races: at most one tracker per sessionID survives.
//
// trackerFromMapValue centralizes the type assertion so the lint-mandated
// two-value form lives in one place; the stored type is always
// *refreshAttemptTracker by construction.
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
t, _ := v.(*refreshAttemptTracker)
return t
}
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
return trackerFromMapValue(v)
}
fresh := &refreshAttemptTracker{}
fresh.state.Store(&attemptState{windowStartNano: time.Now().UnixNano()})
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
return trackerFromMapValue(actual)
}
// mutateState performs a CompareAndSwap loop that applies mutate to the
// current snapshot. mutate must be PURE: it receives an immutable view of
// the current state and returns a fresh *attemptState. If mutate returns nil
// the update is skipped (used by isInCooldown for "no change needed" paths).
//
// Retries on CAS conflict are bounded by the number of concurrent writers —
// in practice 1-3. Under Yaegi each retry pays the dispatch cost of one Load
// + one CompareAndSwap; still cheaper than the previous per-field atomic
// sequence and immune to the cross-field race the v1.0.15 design had.
func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *attemptState) *attemptState {
for {
select {
case item := <-rc.delayedCleanupQueue:
// Wait until cleanup time
waitDuration := time.Until(item.cleanupAt)
if waitDuration > 0 {
select {
case <-time.After(waitDuration):
case <-rc.stopChan:
return
}
}
rc.performCleanup(item.tokenHash)
case <-rc.stopChan:
return
cur := t.stateOf()
next := mutate(cur)
if next == nil {
return cur
}
if t.state.CompareAndSwap(cur, next) {
return next
}
}
}
// isInCooldown checks if a session is in cooldown after recording an attempt
// isInCooldown checks if a session is in cooldown. Snapshot-based: every
// transition publishes a fresh *attemptState atomically so readers never see
// a partially-updated state. The previous per-field atomic design had a
// benign race in the cooldown-exit path (cooldownEndNano reset before
// attempts reset) that could double-trigger cooldown.
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
tracker, exists := rc.sessionRefreshAttempts[sessionID]
if !exists {
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return false // No tracker means first attempt, not in cooldown
}
tracker := trackerFromMapValue(v)
now := time.Now()
nowNano := now.UnixNano()
maxAttempts := rc.config.MaxRefreshAttempts
window := rc.config.RefreshAttemptWindow
cooldownPeriod := rc.config.RefreshCooldownPeriod
// Check if already in cooldown
if tracker.inCooldown {
if now.After(tracker.cooldownEndTime) {
// Cooldown expired, reset tracker
tracker.inCooldown = false
tracker.attempts = 1 // Already recorded one attempt
tracker.consecutiveFailures = 0
tracker.windowStartTime = now
return false
cur := tracker.stateOf()
// Already in cooldown?
if cur.cooldownEndNano != 0 {
if nowNano <= cur.cooldownEndNano {
return true // still in cooldown
}
return true // Still in cooldown
}
// Check if window expired
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
// Reset window
tracker.attempts = 1 // Already recorded one attempt
tracker.windowStartTime = now
// Cooldown expired: atomically publish a fresh state with the window
// restarted from one attempt. Whichever goroutine wins the CAS sets
// the new snapshot; losers see it via the next stateOf load.
tracker.mutateState(func(s *attemptState) *attemptState {
if s.cooldownEndNano == 0 || nowNano <= s.cooldownEndNano {
return nil // someone else already reset, or back in cooldown
}
return &attemptState{
windowStartNano: nowNano,
attempts: 1,
}
})
return false
}
// Check if just exceeded attempt limit
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
// Enter cooldown now
tracker.inCooldown = true
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
sessionID, tracker.attempts)
// Window expired?
if time.Duration(nowNano-cur.windowStartNano) > window {
tracker.mutateState(func(s *attemptState) *attemptState {
if time.Duration(nowNano-s.windowStartNano) <= window {
return nil
}
next := *s
next.windowStartNano = nowNano
next.attempts = 1
return &next
})
return false
}
// Just exceeded attempt limit?
if int(cur.attempts) >= maxAttempts {
end := now.Add(cooldownPeriod).UnixNano()
published := tracker.mutateState(func(s *attemptState) *attemptState {
if s.cooldownEndNano != 0 {
return nil
}
next := *s
next.cooldownEndNano = end
return &next
})
if published.cooldownEndNano == end {
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
sessionID, published.attempts)
}
return true
}
return false
}
// recordRefreshAttempt records a refresh attempt for rate limiting
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free
// snapshot mutation; attempts and lastAttemptNano are advanced atomically.
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
tracker, exists := rc.sessionRefreshAttempts[sessionID]
if !exists {
tracker = &refreshAttemptTracker{
windowStartTime: time.Now(),
}
rc.sessionRefreshAttempts[sessionID] = tracker
}
atomic.AddInt32(&tracker.attempts, 1)
tracker.lastAttemptTime = time.Now()
tracker := rc.getOrCreateTracker(sessionID)
nowNano := time.Now().UnixNano()
tracker.mutateState(func(s *attemptState) *attemptState {
next := *s
next.attempts++
next.lastAttemptNano = nowNano
return &next
})
}
// recordRefreshSuccess records a successful refresh
// recordRefreshSuccess records a successful refresh: zero consecutiveFailures.
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
tracker.consecutiveFailures = 0
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return
}
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
if s.consecutiveFailures == 0 {
return nil
}
next := *s
next.consecutiveFailures = 0
return &next
})
}
// recordRefreshFailure records a failed refresh
// recordRefreshFailure records a failed refresh: increments consecutiveFailures.
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
atomic.AddInt32(&tracker.consecutiveFailures, 1)
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
if !ok {
return
}
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
next := *s
next.consecutiveFailures++
return &next
})
}
// hashRefreshToken creates a hash of the refresh token for deduplication
func (rc *RefreshCoordinator) hashRefreshToken(token string) string {
return refreshCoordinatorSessionID(token)
}
// refreshCoordinatorSessionID derives a stable identifier from a refresh token
// for both deduplication and per-session attempt tracking. Using sha256 of the
// raw token means each rotation produces a fresh sessionID with its own attempt
// budget, which is what we want.
func refreshCoordinatorSessionID(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// isUnderMemoryPressure checks if the system is under memory pressure
// refreshCoordinatorWaitTimeout caps how long a request may wait for a
// coordinated refresh result. It is wider than RefreshTimeout so a follower
// always sees the leader's result instead of timing out independently.
const refreshCoordinatorWaitTimeout = 35 * time.Second
// isUnderMemoryPressure checks if the system is under memory pressure by
// consulting the global memory monitor. Returns true when pressure reaches
// High or Critical, at which point we refuse new refresh operations to
// avoid aggravating an already-stressed heap.
func (rc *RefreshCoordinator) isUnderMemoryPressure() bool {
// This is a simplified check - in production you'd want to use runtime.MemStats
// or system-specific memory monitoring
return false // Placeholder - implement actual memory check
monitor := GetGlobalMemoryMonitor()
if monitor == nil {
return false
}
return monitor.GetMemoryPressure() >= MemoryPressureHigh
}
// cleanupRoutine periodically cleans up stale tracking entries
@@ -522,20 +668,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
}
}
// cleanupStaleEntries removes outdated tracking entries
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
// via sync.Map.Range; safe to race with concurrent reads/writes.
func (rc *RefreshCoordinator) cleanupStaleEntries() {
now := time.Now()
rc.attemptsMutex.Lock()
defer rc.attemptsMutex.Unlock()
// Clean up old session trackers
for sessionID, tracker := range rc.sessionRefreshAttempts {
// Remove trackers that haven't been used recently
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
delete(rc.sessionRefreshAttempts, sessionID)
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
tracker := trackerFromMapValue(value)
if tracker == nil {
return true
}
}
if tracker.stateOf().lastAttemptNano < cutoff {
// Compare-and-delete to avoid evicting a tracker that was just
// re-used by a concurrent caller. We compare by pointer identity.
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
}
return true
})
}
// GetMetrics returns current coordinator metrics
@@ -553,78 +701,60 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
}
}
// Shutdown gracefully shuts down the coordinator
// Shutdown gracefully shuts down the coordinator. Pending delayed-cleanup
// timers are NOT canceled explicitly: time.AfterFunc callbacks are tiny
// (one map LoadAndDelete) and harmless after Shutdown — sync.Map operations
// remain safe on an unused coordinator until GC.
func (rc *RefreshCoordinator) Shutdown() {
close(rc.stopChan)
// Cancel all pending cleanup timers
rc.cleanupTimerMu.Lock()
for _, timer := range rc.cleanupTimers {
timer.Stop()
}
rc.cleanupTimers = make(map[string]*time.Timer)
rc.cleanupTimerMu.Unlock()
rc.wg.Wait()
}
// AllowRequest checks if the circuit breaker allows a request
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
state := atomic.LoadInt32(&cb.state)
switch state {
case 0: // Closed
switch atomic.LoadInt32(&cb.state) {
case 0: // closed
return true
case 1: // Open
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
// Try to transition to half-open
case 1: // open
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
// Transition to half-open; first CAS winner gets the probe.
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
return true
}
}
return false
case 2: // Half-open
case 2: // half-open
return true
default:
return false
}
}
// RecordSuccess records a successful operation
// RecordSuccess records a successful operation. Lock-free.
func (cb *RefreshCircuitBreaker) RecordSuccess() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
state := atomic.LoadInt32(&cb.state)
if state == 2 { // Half-open
// Close the circuit
switch atomic.LoadInt32(&cb.state) {
case 2: // half-open -> close
atomic.StoreInt32(&cb.state, 0)
atomic.StoreInt32(&cb.failures, 0)
} else if state == 0 { // Closed
// Reset failure count on success
case 0: // closed
atomic.StoreInt32(&cb.failures, 0)
}
cb.lastSuccessTime = time.Now()
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
}
// RecordFailure records a failed operation
// RecordFailure records a failed operation. Lock-free.
func (cb *RefreshCircuitBreaker) RecordFailure() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
failures := atomic.AddInt32(&cb.failures, 1)
cb.lastFailureTime = time.Now()
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
state := atomic.LoadInt32(&cb.state)
if state == 0 && int(failures) >= cb.config.MaxFailures {
// Open the circuit
atomic.StoreInt32(&cb.state, 1)
} else if state == 2 {
// Half-open failed, return to open
switch atomic.LoadInt32(&cb.state) {
case 0:
if int(failures) >= cb.config.MaxFailures {
atomic.StoreInt32(&cb.state, 1)
}
case 2:
// Half-open probe failed -> back to open.
atomic.StoreInt32(&cb.state, 1)
}
}
+28 -34
View File
@@ -165,9 +165,14 @@ func TestRefreshRateLimiting(t *testing.T) {
time.Sleep(150 * time.Millisecond)
}
// Verify that cooldown was triggered after max attempts
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
// Verify that cooldown was triggered after max attempts.
// With applyLeaderGates checking cooldown BEFORE recording the attempt
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
// run to completion and the (N+1)th is denied. Previously the Nth was
// denied as it tried to record, which under burst load let multiple
// concurrent leaders increment past the limit before any one of them
// observed the gate.
expectedSuccessfulAttempts := config.MaxRefreshAttempts
if attempts != expectedSuccessfulAttempts {
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
}
@@ -365,10 +370,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
}
}
// Verify cleanup is working
coordinator.attemptsMutex.RLock()
sessionCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
// Verify cleanup is working. sync.Map has no Len(); count via Range.
sessionCount := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
sessionCount++
return true
})
// Should have cleaned up old sessions (only recent ones remain)
if sessionCount > numWorkers*2 {
@@ -650,24 +657,23 @@ func TestCleanupRoutine(t *testing.T) {
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
}
// Verify sessions exist
coordinator.attemptsMutex.RLock()
initialCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
countSessions := func() int {
n := 0
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
n++
return true
})
return n
}
if initialCount != 5 {
if initialCount := countSessions(); initialCount != 5 {
t.Errorf("Expected 5 sessions, got %d", initialCount)
}
// Wait for cleanup to run (2x window + cleanup interval)
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
// Verify sessions were cleaned up
coordinator.attemptsMutex.RLock()
finalCount := len(coordinator.sessionRefreshAttempts)
coordinator.attemptsMutex.RUnlock()
if finalCount != 0 {
if finalCount := countSessions(); finalCount != 0 {
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
}
}
@@ -720,11 +726,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
currentGoroutines := runtime.NumGoroutine()
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
// Check timer count
coordinator.cleanupTimerMu.Lock()
timerCount := len(coordinator.cleanupTimers)
coordinator.cleanupTimerMu.Unlock()
t.Logf("Active cleanup timers: %d", timerCount)
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
// fire performCleanup directly. This test now only checks the goroutine
// budget, which was always the real invariant.)
// With timer-based cleanup, goroutine increase should be minimal
// Timers don't create goroutines - they use the runtime timer heap
@@ -740,19 +744,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
initialGoroutines, currentGoroutines, goroutineIncrease)
}
// Wait for timers to fire and cleanup
// Wait for timers to fire and cleanup.
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
// Verify timers were cleaned up
coordinator.cleanupTimerMu.Lock()
remainingTimers := len(coordinator.cleanupTimers)
coordinator.cleanupTimerMu.Unlock()
// Most timers should have fired and been removed
if remainingTimers > 10 {
t.Errorf("Too many cleanup timers remaining: %d", remainingTimers)
}
// Verify goroutines returned to near initial
runtime.GC()
time.Sleep(50 * time.Millisecond)

Some files were not shown because too many files have changed in this diff Show More